@@ -8,22 +8,24 @@ using MQTTnet.Implementations; | |||
namespace MQTTnet.AspNetCore | |||
{ | |||
public class MqttWebSocketServerChannel : IMqttChannel, IDisposable | |||
public class MqttWebSocketServerChannel : IMqttChannel | |||
{ | |||
private WebSocket _webSocket; | |||
private readonly MqttWebSocketChannel _channel; | |||
public MqttWebSocketServerChannel(WebSocket webSocket) | |||
{ | |||
_webSocket = webSocket ?? throw new ArgumentNullException(nameof(webSocket)); | |||
SendStream = new WebSocketStream(_webSocket); | |||
_channel = new MqttWebSocketChannel(webSocket); | |||
ReceiveStream = SendStream; | |||
} | |||
public Stream SendStream { get; private set; } | |||
public Stream ReceiveStream { get; private set; } | |||
private Stream SendStream { get; set; } | |||
private Stream ReceiveStream { get; set; } | |||
public Task ConnectAsync() | |||
public Task ConnectAsync(CancellationToken cancellationToken) | |||
{ | |||
return Task.CompletedTask; | |||
} | |||
@@ -37,7 +39,7 @@ namespace MQTTnet.AspNetCore | |||
try | |||
{ | |||
await _webSocket.CloseAsync(WebSocketCloseStatus.NormalClosure, string.Empty, CancellationToken.None); | |||
await _webSocket.CloseAsync(WebSocketCloseStatus.NormalClosure, string.Empty, CancellationToken.None).ConfigureAwait(false); | |||
} | |||
finally | |||
{ | |||
@@ -45,6 +47,16 @@ namespace MQTTnet.AspNetCore | |||
} | |||
} | |||
public Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) | |||
{ | |||
return ReceiveStream.ReadAsync(buffer, offset, count, cancellationToken); | |||
} | |||
public Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) | |||
{ | |||
return SendStream.WriteAsync(buffer, offset, count, cancellationToken); | |||
} | |||
public void Dispose() | |||
{ | |||
SendStream?.Dispose(); | |||
@@ -1,5 +1,4 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Threading; | |||
using System.Threading.Tasks; | |||
using MQTTnet.Packets; | |||
@@ -11,7 +10,7 @@ namespace MQTTnet.Adapter | |||
{ | |||
IMqttPacketSerializer PacketSerializer { get; } | |||
Task ConnectAsync(TimeSpan timeout); | |||
Task ConnectAsync(TimeSpan timeout, CancellationToken cancellationToken); | |||
Task DisconnectAsync(TimeSpan timeout); | |||
@@ -32,12 +32,12 @@ namespace MQTTnet.Adapter | |||
public IMqttPacketSerializer PacketSerializer { get; } | |||
public Task ConnectAsync(TimeSpan timeout) | |||
public Task ConnectAsync(TimeSpan timeout, CancellationToken cancellationToken) | |||
{ | |||
ThrowIfDisposed(); | |||
_logger.Verbose<MqttChannelAdapter>("Connecting [Timeout={0}]", timeout); | |||
return ExecuteAndWrapExceptionAsync(() => _channel.ConnectAsync().TimeoutAfter(timeout)); | |||
return ExecuteAndWrapExceptionAsync(() => _channel.ConnectAsync(cancellationToken).TimeoutAfter(timeout)); | |||
} | |||
public Task DisconnectAsync(TimeSpan timeout) | |||
@@ -50,44 +50,32 @@ namespace MQTTnet.Adapter | |||
public async Task SendPacketsAsync(TimeSpan timeout, CancellationToken cancellationToken, MqttBasePacket[] packets) | |||
{ | |||
ThrowIfDisposed(); | |||
foreach (var packet in packets) | |||
{ | |||
await SendPacketsAsync(timeout, cancellationToken, packet).ConfigureAwait(false); | |||
if (packet == null) | |||
{ | |||
continue; | |||
} | |||
await SendPacketAsync(timeout, cancellationToken, packet).ConfigureAwait(false); | |||
} | |||
} | |||
private Task SendPacketsAsync(TimeSpan timeout, CancellationToken cancellationToken, MqttBasePacket packet) | |||
private Task SendPacketAsync(TimeSpan timeout, CancellationToken cancellationToken, MqttBasePacket packet) | |||
{ | |||
ThrowIfDisposed(); | |||
if (packet == null) | |||
return ExecuteAndWrapExceptionAsync(() => | |||
{ | |||
return Task.FromResult(0); | |||
} | |||
return ExecuteAndWrapExceptionAsync(async () => | |||
{ | |||
if (cancellationToken.IsCancellationRequested) | |||
{ | |||
return; | |||
} | |||
_logger.Verbose<MqttChannelAdapter>("TX >>> {0} [Timeout={1}]", packet, timeout); | |||
var packetData = PacketSerializer.Serialize(packet); | |||
if (cancellationToken.IsCancellationRequested) | |||
{ | |||
return; | |||
} | |||
await _channel.SendStream.WriteAsync( | |||
return _channel.WriteAsync( | |||
packetData.Array, | |||
packetData.Offset, | |||
packetData.Count, | |||
cancellationToken).ConfigureAwait(false); | |||
await _channel.SendStream.FlushAsync(cancellationToken).ConfigureAwait(false); | |||
cancellationToken); | |||
}); | |||
} | |||
@@ -101,7 +89,6 @@ namespace MQTTnet.Adapter | |||
ReceivedMqttPacket receivedMqttPacket = null; | |||
try | |||
{ | |||
if (timeout > TimeSpan.Zero) | |||
{ | |||
var timeoutCts = new CancellationTokenSource(timeout); | |||
@@ -109,14 +96,14 @@ namespace MQTTnet.Adapter | |||
try | |||
{ | |||
receivedMqttPacket = await ReceiveAsync(_channel.ReceiveStream, linkedCts.Token).ConfigureAwait(false); | |||
receivedMqttPacket = await ReceiveAsync(_channel, linkedCts.Token).ConfigureAwait(false); | |||
} | |||
catch (OperationCanceledException ex) | |||
catch (OperationCanceledException exception) | |||
{ | |||
var timedOut = linkedCts.IsCancellationRequested && !cancellationToken.IsCancellationRequested; | |||
if (timedOut) | |||
{ | |||
throw new MqttCommunicationTimedOutException(ex); | |||
throw new MqttCommunicationTimedOutException(exception); | |||
} | |||
else | |||
{ | |||
@@ -126,7 +113,7 @@ namespace MQTTnet.Adapter | |||
} | |||
else | |||
{ | |||
receivedMqttPacket = await ReceiveAsync(_channel.ReceiveStream, cancellationToken).ConfigureAwait(false); | |||
receivedMqttPacket = await ReceiveAsync(_channel, cancellationToken).ConfigureAwait(false); | |||
} | |||
if (receivedMqttPacket == null || cancellationToken.IsCancellationRequested) | |||
@@ -151,9 +138,9 @@ namespace MQTTnet.Adapter | |||
return packet; | |||
} | |||
private static async Task<ReceivedMqttPacket> ReceiveAsync(Stream stream, CancellationToken cancellationToken) | |||
private static async Task<ReceivedMqttPacket> ReceiveAsync(IMqttChannel channel, CancellationToken cancellationToken) | |||
{ | |||
var header = await MqttPacketReader.ReadHeaderAsync(stream, cancellationToken).ConfigureAwait(false); | |||
var header = await MqttPacketReader.ReadHeaderAsync(channel, cancellationToken).ConfigureAwait(false); | |||
if (header == null) | |||
{ | |||
return null; | |||
@@ -166,7 +153,7 @@ namespace MQTTnet.Adapter | |||
var body = header.BodyLength <= ReadBufferSize ? new MemoryStream(header.BodyLength) : new MemoryStream(); | |||
var buffer = new byte[ReadBufferSize]; | |||
var buffer = new byte[Math.Min(ReadBufferSize, header.BodyLength)]; | |||
while (body.Length < header.BodyLength) | |||
{ | |||
var bytesLeft = header.BodyLength - (int)body.Length; | |||
@@ -175,7 +162,7 @@ namespace MQTTnet.Adapter | |||
bytesLeft = buffer.Length; | |||
} | |||
var readBytesCount = await stream.ReadAsync(buffer, 0, bytesLeft, cancellationToken).ConfigureAwait(false); | |||
var readBytesCount = await channel.ReadAsync(buffer, 0, bytesLeft, cancellationToken).ConfigureAwait(false); | |||
// Check if the client closed the connection before sending the full body. | |||
if (readBytesCount == 0) | |||
@@ -1,15 +1,15 @@ | |||
using System; | |||
using System.IO; | |||
using System.Threading; | |||
using System.Threading.Tasks; | |||
namespace MQTTnet.Channel | |||
{ | |||
public interface IMqttChannel : IDisposable | |||
{ | |||
Stream SendStream { get; } | |||
Stream ReceiveStream { get; } | |||
Task ConnectAsync(); | |||
Task ConnectAsync(CancellationToken cancellationToken); | |||
Task DisconnectAsync(); | |||
Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken); | |||
Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken); | |||
} | |||
} |
@@ -60,7 +60,7 @@ namespace MQTTnet.Client | |||
_adapter = _adapterFactory.CreateClientAdapter(options, _logger); | |||
_logger.Verbose<MqttClient>("Trying to connect with server."); | |||
await _adapter.ConnectAsync(_options.CommunicationTimeout).ConfigureAwait(false); | |||
await _adapter.ConnectAsync(_options.CommunicationTimeout, _cancellationTokenSource.Token).ConfigureAwait(false); | |||
_logger.Verbose<MqttClient>("Connection with server established."); | |||
await StartReceivingPacketsAsync().ConfigureAwait(false); | |||
@@ -92,14 +92,9 @@ namespace MQTTnet.Client | |||
public async Task DisconnectAsync() | |||
{ | |||
if (!IsConnected) | |||
{ | |||
return; | |||
} | |||
try | |||
{ | |||
if (!_cancellationTokenSource.IsCancellationRequested) | |||
if (IsConnected && !_cancellationTokenSource.IsCancellationRequested) | |||
{ | |||
await SendAsync(new MqttDisconnectPacket()).ConfigureAwait(false); | |||
} | |||
@@ -6,7 +6,7 @@ | |||
public int? Port { get; set; } | |||
public int BufferSize { get; set; } = 20 * 4096; | |||
public int BufferSize { get; set; } = 4096; | |||
public MqttClientTlsOptions TlsOptions { get; set; } = new MqttClientTlsOptions(); | |||
} | |||
@@ -4,6 +4,7 @@ using System.Collections.Generic; | |||
using System.IO; | |||
using System.Linq; | |||
using System.Runtime.InteropServices.WindowsRuntime; | |||
using System.Threading; | |||
using System.Threading.Tasks; | |||
using Windows.Networking; | |||
using Windows.Networking.Sockets; | |||
@@ -17,36 +18,35 @@ namespace MQTTnet.Implementations | |||
{ | |||
// ReSharper disable once MemberCanBePrivate.Global | |||
// ReSharper disable once AutoPropertyCanBeMadeGetOnly.Global | |||
public static int BufferSize { get; set; } = 4096 * 20; // Can be changed for fine tuning by library user. | |||
public static int BufferSize { get; set; } = 4096; // Can be changed for fine tuning by library user. | |||
private readonly int _bufferSize = BufferSize; | |||
private readonly MqttClientTcpOptions _options; | |||
private StreamSocket _socket; | |||
private Stream _readStream; | |||
private Stream _writeStream; | |||
public MqttTcpChannel(MqttClientTcpOptions options) | |||
{ | |||
_options = options ?? throw new ArgumentNullException(nameof(options)); | |||
_bufferSize = _options.BufferSize; | |||
_bufferSize = options.BufferSize; | |||
} | |||
public MqttTcpChannel(StreamSocket socket) | |||
{ | |||
_socket = socket ?? throw new ArgumentNullException(nameof(socket)); | |||
CreateStreams(); | |||
} | |||
public Stream SendStream { get; private set; } | |||
public Stream ReceiveStream { get; private set; } | |||
public static Func<MqttClientTcpOptions, IEnumerable<ChainValidationResult>> CustomIgnorableServerCertificateErrorsResolver { get; set; } | |||
public async Task ConnectAsync() | |||
public async Task ConnectAsync(CancellationToken cancellationToken) | |||
{ | |||
if (_socket == null) | |||
{ | |||
_socket = new StreamSocket(); | |||
_socket.Control.NoDelay = true; | |||
} | |||
if (!_options.TlsOptions.UseTls) | |||
@@ -65,7 +65,8 @@ namespace MQTTnet.Implementations | |||
await _socket.ConnectAsync(new HostName(_options.Server), _options.GetPort().ToString(), SocketProtectionLevel.Tls12); | |||
} | |||
CreateStreams(); | |||
_readStream = _socket.InputStream.AsStreamForRead(_bufferSize); | |||
_writeStream = _socket.OutputStream.AsStreamForWrite(_bufferSize); | |||
} | |||
public Task DisconnectAsync() | |||
@@ -74,11 +75,22 @@ namespace MQTTnet.Implementations | |||
return Task.FromResult(0); | |||
} | |||
public Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) | |||
{ | |||
return _readStream.ReadAsync(buffer, offset, count, cancellationToken); | |||
} | |||
public async Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) | |||
{ | |||
await _writeStream.WriteAsync(buffer, offset, count, cancellationToken); | |||
await _writeStream.FlushAsync(cancellationToken); | |||
} | |||
public void Dispose() | |||
{ | |||
try | |||
{ | |||
SendStream?.Dispose(); | |||
_readStream?.Dispose(); | |||
} | |||
catch (ObjectDisposedException) | |||
{ | |||
@@ -88,12 +100,12 @@ namespace MQTTnet.Implementations | |||
} | |||
finally | |||
{ | |||
SendStream = null; | |||
_readStream = null; | |||
} | |||
try | |||
{ | |||
ReceiveStream?.Dispose(); | |||
_writeStream?.Dispose(); | |||
} | |||
catch (ObjectDisposedException) | |||
{ | |||
@@ -103,7 +115,7 @@ namespace MQTTnet.Implementations | |||
} | |||
finally | |||
{ | |||
ReceiveStream = null; | |||
_writeStream = null; | |||
} | |||
try | |||
@@ -122,12 +134,6 @@ namespace MQTTnet.Implementations | |||
} | |||
} | |||
private void CreateStreams() | |||
{ | |||
SendStream = _socket.OutputStream.AsStreamForWrite(_bufferSize); | |||
ReceiveStream = _socket.InputStream.AsStreamForRead(_bufferSize); | |||
} | |||
private static Certificate LoadCertificate(MqttClientTcpOptions options) | |||
{ | |||
if (options.TlsOptions.Certificates == null || !options.TlsOptions.Certificates.Any()) | |||
@@ -7,6 +7,7 @@ using System.Security.Cryptography.X509Certificates; | |||
using System.Threading.Tasks; | |||
using System.IO; | |||
using System.Linq; | |||
using System.Threading; | |||
using MQTTnet.Channel; | |||
using MQTTnet.Client; | |||
@@ -14,20 +15,10 @@ namespace MQTTnet.Implementations | |||
{ | |||
public sealed class MqttTcpChannel : IMqttChannel | |||
{ | |||
#if NET452 || NET461 || NETSTANDARD2_0 | |||
// ReSharper disable once MemberCanBePrivate.Global | |||
// ReSharper disable once AutoPropertyCanBeMadeGetOnly.Global | |||
public static int BufferSize { get; set; } = 4096 * 20; // Can be changed for fine tuning by library user. | |||
private readonly int _bufferSize = BufferSize; | |||
#else | |||
private readonly int _bufferSize = 0; | |||
#endif | |||
private readonly MqttClientTcpOptions _options; | |||
private Socket _socket; | |||
private SslStream _sslStream; | |||
private Stream _stream; | |||
/// <summary> | |||
/// called on client sockets are created in connect | |||
@@ -35,7 +26,6 @@ namespace MQTTnet.Implementations | |||
public MqttTcpChannel(MqttClientTcpOptions options) | |||
{ | |||
_options = options ?? throw new ArgumentNullException(nameof(options)); | |||
_bufferSize = options.BufferSize; | |||
} | |||
/// <summary> | |||
@@ -45,21 +35,17 @@ namespace MQTTnet.Implementations | |||
public MqttTcpChannel(Socket socket, SslStream sslStream) | |||
{ | |||
_socket = socket ?? throw new ArgumentNullException(nameof(socket)); | |||
_sslStream = sslStream; | |||
CreateStreams(); | |||
CreateStream(sslStream); | |||
} | |||
public Stream SendStream { get; private set; } | |||
public Stream ReceiveStream { get; private set; } | |||
public static Func<X509Certificate, X509Chain, SslPolicyErrors, MqttClientTcpOptions, bool> CustomCertificateValidationCallback { get; set; } | |||
public async Task ConnectAsync() | |||
public async Task ConnectAsync(CancellationToken cancellationToken) | |||
{ | |||
if (_socket == null) | |||
{ | |||
_socket = new Socket(SocketType.Stream, ProtocolType.Tcp); | |||
_socket = new Socket(SocketType.Stream, ProtocolType.Tcp) { NoDelay = true }; | |||
} | |||
#if NET452 || NET461 | |||
@@ -68,15 +54,14 @@ namespace MQTTnet.Implementations | |||
await _socket.ConnectAsync(_options.Server, _options.GetPort()).ConfigureAwait(false); | |||
#endif | |||
_socket.NoDelay = true; | |||
SslStream sslStream = null; | |||
if (_options.TlsOptions.UseTls) | |||
{ | |||
_sslStream = new SslStream(new NetworkStream(_socket, true), false, InternalUserCertificateValidationCallback); | |||
await _sslStream.AuthenticateAsClientAsync(_options.Server, LoadCertificates(), SslProtocols.Tls12, _options.TlsOptions.IgnoreCertificateRevocationErrors).ConfigureAwait(false); | |||
sslStream = new SslStream(new NetworkStream(_socket, true), false, InternalUserCertificateValidationCallback); | |||
await sslStream.AuthenticateAsClientAsync(_options.Server, LoadCertificates(), SslProtocols.Tls12, _options.TlsOptions.IgnoreCertificateRevocationErrors).ConfigureAwait(false); | |||
} | |||
CreateStreams(); | |||
CreateStream(sslStream); | |||
} | |||
public Task DisconnectAsync() | |||
@@ -85,46 +70,21 @@ namespace MQTTnet.Implementations | |||
return Task.FromResult(0); | |||
} | |||
public void Dispose() | |||
public Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) | |||
{ | |||
var oneStreamIsUsed = SendStream != null && ReceiveStream != null && ReferenceEquals(SendStream, ReceiveStream); | |||
try | |||
{ | |||
SendStream?.Dispose(); | |||
} | |||
catch (ObjectDisposedException) | |||
{ | |||
} | |||
catch (NullReferenceException) | |||
{ | |||
} | |||
finally | |||
{ | |||
SendStream = null; | |||
} | |||
return _stream.ReadAsync(buffer, offset, count, cancellationToken); | |||
} | |||
try | |||
{ | |||
if (!oneStreamIsUsed) | |||
{ | |||
ReceiveStream?.Dispose(); | |||
} | |||
} | |||
catch (ObjectDisposedException) | |||
{ | |||
} | |||
catch (NullReferenceException) | |||
{ | |||
} | |||
finally | |||
{ | |||
ReceiveStream = null; | |||
} | |||
public Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) | |||
{ | |||
return _stream.WriteAsync(buffer, offset, count, cancellationToken); | |||
} | |||
public void Dispose() | |||
{ | |||
try | |||
{ | |||
_sslStream?.Dispose(); | |||
_stream?.Dispose(); | |||
} | |||
catch (ObjectDisposedException) | |||
{ | |||
@@ -134,7 +94,7 @@ namespace MQTTnet.Implementations | |||
} | |||
finally | |||
{ | |||
_sslStream = null; | |||
_stream = null; | |||
} | |||
try | |||
@@ -200,25 +160,16 @@ namespace MQTTnet.Implementations | |||
return certificates; | |||
} | |||
private void CreateStreams() | |||
private void CreateStream(Stream stream) | |||
{ | |||
Stream stream; | |||
if (_sslStream != null) | |||
if (stream != null) | |||
{ | |||
stream = _sslStream; | |||
_stream = stream; | |||
} | |||
else | |||
{ | |||
stream = new NetworkStream(_socket, true); | |||
_stream = new NetworkStream(_socket, true); | |||
} | |||
#if NET452 || NET461 || NETSTANDARD2_0 | |||
SendStream = new BufferedStream(stream, _bufferSize); | |||
ReceiveStream = new BufferedStream(stream, _bufferSize); | |||
#else | |||
SendStream = stream; | |||
ReceiveStream = stream; | |||
#endif | |||
} | |||
} | |||
} | |||
@@ -62,6 +62,8 @@ namespace MQTTnet.Implementations | |||
{ | |||
try | |||
{ | |||
args.Socket.Control.NoDelay = true; | |||
var clientAdapter = new MqttChannelAdapter(new MqttTcpChannel(args.Socket), new MqttPacketSerializer(), _logger); | |||
ClientAccepted?.Invoke(this, new MqttServerAdapterClientAcceptedEventArgs(clientAdapter)); | |||
} | |||
@@ -43,7 +43,7 @@ namespace MQTTnet.Implementations | |||
_defaultEndpointSocket.Bind(new IPEndPoint(options.DefaultEndpointOptions.BoundIPAddress, options.GetDefaultEndpointPort())); | |||
_defaultEndpointSocket.Listen(options.ConnectionBacklog); | |||
Task.Run(async () => await AcceptDefaultEndpointConnectionsAsync(_cancellationTokenSource.Token).ConfigureAwait(false), _cancellationTokenSource.Token).ConfigureAwait(false); | |||
Task.Run(() => AcceptDefaultEndpointConnectionsAsync(_cancellationTokenSource.Token), _cancellationTokenSource.Token); | |||
} | |||
if (options.TlsEndpointOptions.IsEnabled) | |||
@@ -63,7 +63,7 @@ namespace MQTTnet.Implementations | |||
_tlsEndpointSocket.Bind(new IPEndPoint(options.TlsEndpointOptions.BoundIPAddress, options.GetTlsEndpointPort())); | |||
_tlsEndpointSocket.Listen(options.ConnectionBacklog); | |||
Task.Run(async () => await AcceptTlsEndpointConnectionsAsync(_cancellationTokenSource.Token).ConfigureAwait(false), _cancellationTokenSource.Token).ConfigureAwait(false); | |||
Task.Run(() => AcceptTlsEndpointConnectionsAsync(_cancellationTokenSource.Token), _cancellationTokenSource.Token); | |||
} | |||
return Task.FromResult(0); | |||
@@ -1,11 +1,13 @@ | |||
using System; | |||
using System.IO; | |||
using System.Collections.Generic; | |||
using System.Linq; | |||
using System.Net.WebSockets; | |||
using System.Security.Cryptography.X509Certificates; | |||
using System.Threading; | |||
using System.Threading.Tasks; | |||
using MQTTnet.Channel; | |||
using MQTTnet.Client; | |||
using MQTTnet.Exceptions; | |||
namespace MQTTnet.Implementations | |||
{ | |||
@@ -13,20 +15,25 @@ namespace MQTTnet.Implementations | |||
{ | |||
// ReSharper disable once MemberCanBePrivate.Global | |||
// ReSharper disable once AutoPropertyCanBeMadeGetOnly.Global | |||
public static int BufferSize { get; set; } = 4096 * 20; // Can be changed for fine tuning by library user. | |||
public static int BufferSize { get; set; } = 4096; // Can be changed for fine tuning by library user. | |||
private readonly byte[] _chunckBuffer = new byte[BufferSize]; | |||
private readonly Queue<byte> _buffer = new Queue<byte>(BufferSize); | |||
private readonly MqttClientWebSocketOptions _options; | |||
private ClientWebSocket _webSocket; | |||
private WebSocket _webSocket; | |||
public MqttWebSocketChannel(MqttClientWebSocketOptions options) | |||
{ | |||
_options = options ?? throw new ArgumentNullException(nameof(options)); | |||
} | |||
public Stream SendStream { get; private set; } | |||
public Stream ReceiveStream { get; private set; } | |||
public MqttWebSocketChannel(WebSocket webSocket) | |||
{ | |||
_webSocket = webSocket ?? throw new ArgumentNullException(nameof(webSocket)); | |||
} | |||
public async Task ConnectAsync() | |||
public async Task ConnectAsync(CancellationToken cancellationToken) | |||
{ | |||
var uri = _options.Uri; | |||
if (!uri.StartsWith("ws://", StringComparison.OrdinalIgnoreCase) && !uri.StartsWith("wss://", StringComparison.OrdinalIgnoreCase)) | |||
@@ -41,13 +48,13 @@ namespace MQTTnet.Implementations | |||
} | |||
} | |||
_webSocket = new ClientWebSocket(); | |||
var clientWebSocket = new ClientWebSocket(); | |||
if (_options.RequestHeaders != null) | |||
{ | |||
foreach (var requestHeader in _options.RequestHeaders) | |||
{ | |||
_webSocket.Options.SetRequestHeader(requestHeader.Key, requestHeader.Value); | |||
clientWebSocket.Options.SetRequestHeader(requestHeader.Key, requestHeader.Value); | |||
} | |||
} | |||
@@ -55,28 +62,26 @@ namespace MQTTnet.Implementations | |||
{ | |||
foreach (var subProtocol in _options.SubProtocols) | |||
{ | |||
_webSocket.Options.AddSubProtocol(subProtocol); | |||
clientWebSocket.Options.AddSubProtocol(subProtocol); | |||
} | |||
} | |||
if (_options.CookieContainer != null) | |||
{ | |||
_webSocket.Options.Cookies = _options.CookieContainer; | |||
clientWebSocket.Options.Cookies = _options.CookieContainer; | |||
} | |||
if (_options.TlsOptions?.UseTls == true && _options.TlsOptions?.Certificates != null) | |||
{ | |||
_webSocket.Options.ClientCertificates = new X509CertificateCollection(); | |||
clientWebSocket.Options.ClientCertificates = new X509CertificateCollection(); | |||
foreach (var certificate in _options.TlsOptions.Certificates) | |||
{ | |||
_webSocket.Options.ClientCertificates.Add(new X509Certificate(certificate)); | |||
clientWebSocket.Options.ClientCertificates.Add(new X509Certificate(certificate)); | |||
} | |||
} | |||
await _webSocket.ConnectAsync(new Uri(uri), CancellationToken.None).ConfigureAwait(false); | |||
SendStream = new WebSocketStream(_webSocket); | |||
ReceiveStream = SendStream; | |||
await clientWebSocket.ConnectAsync(new Uri(uri), cancellationToken).ConfigureAwait(false); | |||
_webSocket = clientWebSocket; | |||
} | |||
public async Task DisconnectAsync() | |||
@@ -94,6 +99,56 @@ namespace MQTTnet.Implementations | |||
Dispose(); | |||
} | |||
public async Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) | |||
{ | |||
var bytesRead = 0; | |||
// Use existing date from buffer. | |||
while (count > 0 && _buffer.Any()) | |||
{ | |||
buffer[offset] = _buffer.Dequeue(); | |||
count--; | |||
bytesRead++; | |||
offset++; | |||
} | |||
if (count == 0) | |||
{ | |||
return bytesRead; | |||
} | |||
// Fetch new data if the buffer is not full. | |||
while (_webSocket.State == WebSocketState.Open) | |||
{ | |||
await FetchChunkAsync(cancellationToken).ConfigureAwait(false); | |||
while (count > 0 && _buffer.Any()) | |||
{ | |||
buffer[offset] = _buffer.Dequeue(); | |||
count--; | |||
bytesRead++; | |||
offset++; | |||
} | |||
if (count == 0) | |||
{ | |||
return bytesRead; | |||
} | |||
} | |||
if (_webSocket.State == WebSocketState.Closed) | |||
{ | |||
throw new MqttCommunicationException("WebSocket connection closed."); | |||
} | |||
return bytesRead; | |||
} | |||
public Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) | |||
{ | |||
return _webSocket.SendAsync(new ArraySegment<byte>(buffer, offset, count), WebSocketMessageType.Binary, true, cancellationToken); | |||
} | |||
public void Dispose() | |||
{ | |||
try | |||
@@ -108,5 +163,26 @@ namespace MQTTnet.Implementations | |||
_webSocket = null; | |||
} | |||
} | |||
private async Task FetchChunkAsync(CancellationToken cancellationToken) | |||
{ | |||
var response = await _webSocket.ReceiveAsync(new ArraySegment<byte>(_chunckBuffer, 0, _chunckBuffer.Length), cancellationToken).ConfigureAwait(false); | |||
if (response.MessageType == WebSocketMessageType.Close) | |||
{ | |||
await _webSocket.CloseAsync(WebSocketCloseStatus.NormalClosure, string.Empty, cancellationToken).ConfigureAwait(false); | |||
} | |||
else if (response.MessageType == WebSocketMessageType.Binary) | |||
{ | |||
for (var i = 0; i < response.Count; i++) | |||
{ | |||
_buffer.Enqueue(_chunckBuffer[i]); | |||
} | |||
} | |||
else if (response.MessageType == WebSocketMessageType.Text) | |||
{ | |||
throw new MqttProtocolViolationException("WebSocket channel received TEXT message."); | |||
} | |||
} | |||
} | |||
} |
@@ -1,137 +0,0 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.IO; | |||
using System.Linq; | |||
using System.Net.WebSockets; | |||
using System.Threading; | |||
using System.Threading.Tasks; | |||
using MQTTnet.Exceptions; | |||
namespace MQTTnet.Implementations | |||
{ | |||
public class WebSocketStream : Stream | |||
{ | |||
private readonly byte[] _chunckBuffer = new byte[MqttWebSocketChannel.BufferSize]; | |||
private readonly Queue<byte> _buffer = new Queue<byte>(MqttWebSocketChannel.BufferSize); | |||
private readonly WebSocket _webSocket; | |||
public WebSocketStream(WebSocket webSocket) | |||
{ | |||
_webSocket = webSocket ?? throw new ArgumentNullException(nameof(webSocket)); | |||
} | |||
public override bool CanRead => true; | |||
public override bool CanSeek => false; | |||
public override bool CanWrite => true; | |||
public override long Length => throw new NotSupportedException(); | |||
public override long Position | |||
{ | |||
get => throw new NotSupportedException(); | |||
set => throw new NotSupportedException(); | |||
} | |||
public override void Flush() | |||
{ | |||
} | |||
public override Task FlushAsync(CancellationToken cancellationToken) | |||
{ | |||
return Task.FromResult(0); | |||
} | |||
public override int Read(byte[] buffer, int offset, int count) | |||
{ | |||
return ReadAsync(buffer, offset, count).GetAwaiter().GetResult(); | |||
} | |||
public override async Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) | |||
{ | |||
var bytesRead = 0; | |||
// Use existing date from buffer. | |||
while (count > 0 && _buffer.Any()) | |||
{ | |||
buffer[offset] = _buffer.Dequeue(); | |||
count--; | |||
bytesRead++; | |||
offset++; | |||
} | |||
if (count == 0) | |||
{ | |||
return bytesRead; | |||
} | |||
// Fetch new data if the buffer is not full. | |||
while (_webSocket.State == WebSocketState.Open) | |||
{ | |||
await FetchChunkAsync(cancellationToken).ConfigureAwait(false); | |||
while (count > 0 && _buffer.Any()) | |||
{ | |||
buffer[offset] = _buffer.Dequeue(); | |||
count--; | |||
bytesRead++; | |||
offset++; | |||
} | |||
if (count == 0) | |||
{ | |||
return bytesRead; | |||
} | |||
} | |||
if (_webSocket.State == WebSocketState.Closed) | |||
{ | |||
throw new MqttCommunicationException("WebSocket connection closed."); | |||
} | |||
return bytesRead; | |||
} | |||
public override void Write(byte[] buffer, int offset, int count) | |||
{ | |||
WriteAsync(buffer, offset, count).GetAwaiter().GetResult(); | |||
} | |||
public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) | |||
{ | |||
return _webSocket.SendAsync(new ArraySegment<byte>(buffer, offset, count), WebSocketMessageType.Binary, true, cancellationToken); | |||
} | |||
public override long Seek(long offset, SeekOrigin origin) | |||
{ | |||
throw new NotSupportedException(); | |||
} | |||
public override void SetLength(long value) | |||
{ | |||
throw new NotSupportedException(); | |||
} | |||
private async Task FetchChunkAsync(CancellationToken cancellationToken) | |||
{ | |||
var response = await _webSocket.ReceiveAsync(new ArraySegment<byte>(_chunckBuffer, 0, _chunckBuffer.Length), cancellationToken).ConfigureAwait(false); | |||
if (response.MessageType == WebSocketMessageType.Close) | |||
{ | |||
await _webSocket.CloseAsync(WebSocketCloseStatus.NormalClosure, string.Empty, cancellationToken).ConfigureAwait(false); | |||
} | |||
else if (response.MessageType == WebSocketMessageType.Binary) | |||
{ | |||
for (var i = 0; i < response.Count; i++) | |||
{ | |||
_buffer.Enqueue(_chunckBuffer[i]); | |||
} | |||
} | |||
else if (response.MessageType == WebSocketMessageType.Text) | |||
{ | |||
throw new MqttProtocolViolationException("WebSocket channel received TEXT message."); | |||
} | |||
} | |||
} | |||
} |
@@ -1,5 +1,4 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.IO; | |||
using MQTTnet.Packets; | |||
@@ -11,6 +10,6 @@ namespace MQTTnet.Serializer | |||
ArraySegment<byte> Serialize(MqttBasePacket mqttPacket); | |||
MqttBasePacket Deserialize(MqttPacketHeader header, MemoryStream body); | |||
MqttBasePacket Deserialize(MqttPacketHeader header, Stream body); | |||
} | |||
} |
@@ -1,9 +1,9 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.IO; | |||
using System.Text; | |||
using System.Threading; | |||
using System.Threading.Tasks; | |||
using MQTTnet.Channel; | |||
using MQTTnet.Exceptions; | |||
using MQTTnet.Packets; | |||
using MQTTnet.Protocol; | |||
@@ -13,7 +13,7 @@ namespace MQTTnet.Serializer | |||
public sealed class MqttPacketReader : BinaryReader | |||
{ | |||
private readonly MqttPacketHeader _header; | |||
public MqttPacketReader(MqttPacketHeader header, Stream bodyStream) | |||
: base(bodyStream, Encoding.UTF8, true) | |||
{ | |||
@@ -22,7 +22,7 @@ namespace MQTTnet.Serializer | |||
public bool EndOfRemainingData => BaseStream.Position == _header.BodyLength; | |||
public static async Task<MqttPacketHeader> ReadHeaderAsync(Stream stream, CancellationToken cancellationToken) | |||
public static async Task<MqttPacketHeader> ReadHeaderAsync(IMqttChannel stream, CancellationToken cancellationToken) | |||
{ | |||
if (cancellationToken.IsCancellationRequested) | |||
{ | |||
@@ -33,7 +33,7 @@ namespace MQTTnet.Serializer | |||
// some large delay and thus the thread should be put back to the pool (await). So ReadByte() | |||
// is not an option here. | |||
var buffer = new byte[1]; | |||
var readCount = await stream.ReadAsync(buffer, 0, 1, cancellationToken).ConfigureAwait(false); | |||
var readCount = await stream.ReadAsync(buffer, 0, buffer.Length, cancellationToken).ConfigureAwait(false); | |||
if (readCount <= 0) | |||
{ | |||
return null; | |||
@@ -89,15 +89,14 @@ namespace MQTTnet.Serializer | |||
return ReadBytes(_header.BodyLength - (int)BaseStream.Position); | |||
} | |||
private static async Task<int> ReadBodyLengthAsync(Stream stream, CancellationToken cancellationToken) | |||
private static async Task<int> ReadBodyLengthAsync(IMqttChannel stream, CancellationToken cancellationToken) | |||
{ | |||
// Alorithm taken from https://docs.oasis-open.org/mqtt/mqtt/v3.1.1/errata01/os/mqtt-v3.1.1-errata01-os-complete.html. | |||
var multiplier = 1; | |||
var value = 0; | |||
byte encodedByte; | |||
int encodedByte; | |||
var buffer = new byte[1]; | |||
var readBytes = new List<byte>(); | |||
do | |||
{ | |||
if (cancellationToken.IsCancellationRequested) | |||
@@ -112,12 +111,11 @@ namespace MQTTnet.Serializer | |||
} | |||
encodedByte = buffer[0]; | |||
readBytes.Add(encodedByte); | |||
value += (byte)(encodedByte & 127) * multiplier; | |||
if (multiplier > 128 * 128 * 128) | |||
{ | |||
throw new MqttProtocolViolationException($"Remaining length is invalid (Data={string.Join(",", readBytes)})."); | |||
throw new MqttProtocolViolationException("Remaining length is invalid."); | |||
} | |||
multiplier *= 128; | |||
@@ -27,7 +27,7 @@ namespace MQTTnet.Serializer | |||
var fixedHeader = SerializePacket(packet, writer); | |||
var remainingLength = MqttPacketWriter.GetRemainingLength((int)stream.Length - 5); | |||
var remainingLength = MqttPacketWriter.EncodeRemainingLength((int)stream.Length - 5); | |||
var headerSize = remainingLength.Length + 1; | |||
var headerOffset = 5 - headerSize; | |||
@@ -47,7 +47,7 @@ namespace MQTTnet.Serializer | |||
} | |||
} | |||
public MqttBasePacket Deserialize(MqttPacketHeader header, MemoryStream body) | |||
public MqttBasePacket Deserialize(MqttPacketHeader header, Stream body) | |||
{ | |||
if (header == null) throw new ArgumentNullException(nameof(header)); | |||
if (body == null) throw new ArgumentNullException(nameof(body)); | |||
@@ -183,7 +183,7 @@ namespace MQTTnet.Serializer | |||
var topic = reader.ReadStringWithLengthPrefix(); | |||
ushort packetIdentifier = 0; | |||
ushort? packetIdentifier = null; | |||
if (qualityOfServiceLevel > MqttQualityOfServiceLevel.AtMostOnce) | |||
{ | |||
packetIdentifier = reader.ReadUInt16(); | |||
@@ -191,12 +191,12 @@ namespace MQTTnet.Serializer | |||
var packet = new MqttPublishPacket | |||
{ | |||
PacketIdentifier = packetIdentifier, | |||
Retain = retain, | |||
QualityOfServiceLevel = qualityOfServiceLevel, | |||
Dup = dup, | |||
Topic = topic, | |||
Payload = reader.ReadRemainingData(), | |||
PacketIdentifier = packetIdentifier | |||
QualityOfServiceLevel = qualityOfServiceLevel, | |||
Dup = dup | |||
}; | |||
return packet; | |||
@@ -56,7 +56,7 @@ namespace MQTTnet.Serializer | |||
Write(value); | |||
} | |||
public static byte[] GetRemainingLength(int length) | |||
public static byte[] EncodeRemainingLength(int length) | |||
{ | |||
if (length <= 0) | |||
{ | |||
@@ -82,7 +82,8 @@ namespace MQTTnet.Serializer | |||
offset++; | |||
} while (x > 0); | |||
return bytes.Take(offset).ToArray(); | |||
Array.Resize(ref bytes, offset); | |||
return bytes; | |||
} | |||
} | |||
} |
@@ -18,7 +18,7 @@ MQTTnet is a high performance .NET library for MQTT based communication. It prov | |||
* TLS 1.2 support for client and server (but not UWP servers) | |||
* Extensible communication channels (i.e. In-Memory, TCP, TCP+TLS, WS) | |||
* Lightweight (only the low level implementation of MQTT, no overhead) | |||
* Performance optimized (processing ~40.000 messages / second)* | |||
* Performance optimized (processing ~50.000 messages / second)* | |||
* Interfaces included for mocking and testing | |||
* Access to internal trace messages | |||
* Unit tested (~80 tests) | |||
@@ -11,8 +11,7 @@ namespace MQTTnet.Core.Tests | |||
[TestMethod] | |||
public void MqttPacketReader_EmptyStream() | |||
{ | |||
var memStream = new MemoryStream(); | |||
var header = MqttPacketReader.ReadHeaderAsync(memStream, CancellationToken.None).GetAwaiter().GetResult(); | |||
var header = MqttPacketReader.ReadHeaderAsync(new TestMqttChannel(new MemoryStream()), CancellationToken.None).GetAwaiter().GetResult(); | |||
Assert.IsNull(header); | |||
} | |||
@@ -416,7 +416,7 @@ namespace MQTTnet.Core.Tests | |||
using (var headerStream = new MemoryStream(Join(buffer1))) | |||
{ | |||
var header = MqttPacketReader.ReadHeaderAsync(headerStream, CancellationToken.None).GetAwaiter().GetResult(); | |||
var header = MqttPacketReader.ReadHeaderAsync(new TestMqttChannel(headerStream), CancellationToken.None).GetAwaiter().GetResult(); | |||
using (var bodyStream = new MemoryStream(Join(buffer1), (int)headerStream.Position, header.BodyLength)) | |||
{ | |||
@@ -0,0 +1,41 @@ | |||
using System.IO; | |||
using System.Threading; | |||
using System.Threading.Tasks; | |||
using MQTTnet.Channel; | |||
namespace MQTTnet.Core.Tests | |||
{ | |||
public class TestMqttChannel : IMqttChannel | |||
{ | |||
private readonly MemoryStream _stream; | |||
public TestMqttChannel(MemoryStream stream) | |||
{ | |||
_stream = stream; | |||
} | |||
public void Dispose() | |||
{ | |||
} | |||
public Task ConnectAsync(CancellationToken cancellationToken) | |||
{ | |||
return Task.FromResult(0); | |||
} | |||
public Task DisconnectAsync() | |||
{ | |||
return Task.FromResult(0); | |||
} | |||
public Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) | |||
{ | |||
return _stream.ReadAsync(buffer, offset, count, cancellationToken); | |||
} | |||
public Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) | |||
{ | |||
return _stream.WriteAsync(buffer, offset, count, cancellationToken); | |||
} | |||
} | |||
} |
@@ -1,6 +1,5 @@ | |||
using System; | |||
using System.Collections.Concurrent; | |||
using System.Collections.Generic; | |||
using System.Threading; | |||
using System.Threading.Tasks; | |||
using MQTTnet.Adapter; | |||
@@ -21,7 +20,7 @@ namespace MQTTnet.Core.Tests | |||
{ | |||
} | |||
public Task ConnectAsync(TimeSpan timeout) | |||
public Task ConnectAsync(TimeSpan timeout, CancellationToken cancellationToken) | |||
{ | |||
return Task.FromResult(0); | |||
} | |||
@@ -59,7 +59,7 @@ namespace MQTTnet.TestApp.NetCore | |||
var sentMessagesCount = 0; | |||
while (stopwatch.ElapsedMilliseconds < 1000) | |||
{ | |||
await client.PublishAsync(messages).ConfigureAwait(false); | |||
client.PublishAsync(messages).GetAwaiter().GetResult(); | |||
sentMessagesCount++; | |||
} | |||