diff --git a/Source/MQTTnet.AspnetCore/MqttConnectionContext.cs b/Source/MQTTnet.AspnetCore/MqttConnectionContext.cs index 6febcba..fb5f4f1 100644 --- a/Source/MQTTnet.AspnetCore/MqttConnectionContext.cs +++ b/Source/MQTTnet.AspnetCore/MqttConnectionContext.cs @@ -7,6 +7,7 @@ using MQTTnet.Formatter; using MQTTnet.Packets; using System; using System.IO.Pipelines; +using System.Security.Cryptography.X509Certificates; using System.Threading; using System.Threading.Tasks; @@ -48,6 +49,8 @@ namespace MQTTnet.AspNetCore public bool IsSecureConnection => Http?.HttpContext?.Request?.IsHttps ?? false; + public X509Certificate2 ClientCertificate => Http?.HttpContext?.Connection?.ClientCertificate; + private IHttpContextFeature Http => Connection.Features.Get(); public ConnectionContext Connection { get; } diff --git a/Source/MQTTnet.AspnetCore/MqttWebSocketServerAdapter.cs b/Source/MQTTnet.AspnetCore/MqttWebSocketServerAdapter.cs index 455840c..a052dae 100644 --- a/Source/MQTTnet.AspnetCore/MqttWebSocketServerAdapter.cs +++ b/Source/MQTTnet.AspnetCore/MqttWebSocketServerAdapter.cs @@ -40,19 +40,26 @@ namespace MQTTnet.AspNetCore var endpoint = $"{httpContext.Connection.RemoteIpAddress}:{httpContext.Connection.RemotePort}"; var clientCertificate = await httpContext.Connection.GetClientCertificateAsync().ConfigureAwait(false); - var isSecureConnection = clientCertificate != null; - clientCertificate?.Dispose(); - - var clientHandler = ClientHandler; - if (clientHandler != null) + try { - var writer = new SpanBasedMqttPacketWriter(); - var formatter = new MqttPacketFormatterAdapter(writer); - var channel = new MqttWebSocketChannel(webSocket, endpoint, isSecureConnection); - using (var channelAdapter = new MqttChannelAdapter(channel, formatter, _logger.CreateChildLogger(nameof(MqttWebSocketServerAdapter)))) + var isSecureConnection = clientCertificate != null; + + var clientHandler = ClientHandler; + if (clientHandler != null) { - await clientHandler(channelAdapter).ConfigureAwait(false); - } + var writer = new SpanBasedMqttPacketWriter(); + var formatter = new MqttPacketFormatterAdapter(writer); + var channel = new MqttWebSocketChannel(webSocket, endpoint, isSecureConnection, clientCertificate); + + using (var channelAdapter = new MqttChannelAdapter(channel, formatter, _logger.CreateChildLogger(nameof(MqttWebSocketServerAdapter)))) + { + await clientHandler(channelAdapter).ConfigureAwait(false); + } + } + } + finally + { + clientCertificate?.Dispose(); } } diff --git a/Source/MQTTnet.Extensions.WebSocket4Net/WebSocket4NetMqttChannel.cs b/Source/MQTTnet.Extensions.WebSocket4Net/WebSocket4NetMqttChannel.cs index c3c6be2..68b12ef 100644 --- a/Source/MQTTnet.Extensions.WebSocket4Net/WebSocket4NetMqttChannel.cs +++ b/Source/MQTTnet.Extensions.WebSocket4Net/WebSocket4NetMqttChannel.cs @@ -35,6 +35,8 @@ namespace MQTTnet.Extensions.WebSocket4Net public bool IsSecureConnection { get; private set; } + public X509Certificate2 ClientCertificate { get; } + public async Task ConnectAsync(CancellationToken cancellationToken) { var uri = _webSocketOptions.Uri; diff --git a/Source/MQTTnet/Adapter/IMqttChannelAdapter.cs b/Source/MQTTnet/Adapter/IMqttChannelAdapter.cs index 08bc809..118761f 100644 --- a/Source/MQTTnet/Adapter/IMqttChannelAdapter.cs +++ b/Source/MQTTnet/Adapter/IMqttChannelAdapter.cs @@ -1,4 +1,5 @@ using System; +using System.Security.Cryptography.X509Certificates; using System.Threading; using System.Threading.Tasks; using MQTTnet.Formatter; @@ -12,6 +13,8 @@ namespace MQTTnet.Adapter bool IsSecureConnection { get; } + X509Certificate2 ClientCertificate { get; } + MqttPacketFormatterAdapter PacketFormatterAdapter { get; } long BytesSent { get; } diff --git a/Source/MQTTnet/Adapter/MqttChannelAdapter.cs b/Source/MQTTnet/Adapter/MqttChannelAdapter.cs index a4d71b3..4a0a85d 100644 --- a/Source/MQTTnet/Adapter/MqttChannelAdapter.cs +++ b/Source/MQTTnet/Adapter/MqttChannelAdapter.cs @@ -2,6 +2,7 @@ using System; using System.IO; using System.Net.Sockets; using System.Runtime.InteropServices; +using System.Security.Cryptography.X509Certificates; using System.Threading; using System.Threading.Tasks; using MQTTnet.Channel; @@ -47,6 +48,8 @@ namespace MQTTnet.Adapter public bool IsSecureConnection => _channel.IsSecureConnection; + public X509Certificate2 ClientCertificate => _channel.ClientCertificate; + public MqttPacketFormatterAdapter PacketFormatterAdapter { get; } public long BytesSent => Interlocked.Read(ref _bytesSent); diff --git a/Source/MQTTnet/Channel/IMqttChannel.cs b/Source/MQTTnet/Channel/IMqttChannel.cs index 4848b46..188e55f 100644 --- a/Source/MQTTnet/Channel/IMqttChannel.cs +++ b/Source/MQTTnet/Channel/IMqttChannel.cs @@ -1,4 +1,5 @@ using System; +using System.Security.Cryptography.X509Certificates; using System.Threading; using System.Threading.Tasks; @@ -8,6 +9,7 @@ namespace MQTTnet.Channel { string Endpoint { get; } bool IsSecureConnection { get; } + X509Certificate2 ClientCertificate { get; } Task ConnectAsync(CancellationToken cancellationToken); Task DisconnectAsync(CancellationToken cancellationToken); diff --git a/Source/MQTTnet/Client/Options/MqttClientOptionsBuilder.cs b/Source/MQTTnet/Client/Options/MqttClientOptionsBuilder.cs index 8ced03a..a7aefd1 100644 --- a/Source/MQTTnet/Client/Options/MqttClientOptionsBuilder.cs +++ b/Source/MQTTnet/Client/Options/MqttClientOptionsBuilder.cs @@ -156,6 +156,7 @@ namespace MQTTnet.Client.Options return this; } + // TODO: Consider creating _MqttClientTcpOptionsBuilder_ as overload. public MqttClientOptionsBuilder WithTcpServer(Action optionsBuilder) { if (optionsBuilder == null) throw new ArgumentNullException(nameof(optionsBuilder)); diff --git a/Source/MQTTnet/Client/Options/MqttClientTcpOptions.cs b/Source/MQTTnet/Client/Options/MqttClientTcpOptions.cs index c7b0e80..63a9e02 100644 --- a/Source/MQTTnet/Client/Options/MqttClientTcpOptions.cs +++ b/Source/MQTTnet/Client/Options/MqttClientTcpOptions.cs @@ -10,7 +10,7 @@ namespace MQTTnet.Client.Options public int BufferSize { get; set; } = 65536; - public bool DualMode { get; set; } = true; + public bool? DualMode { get; set; } public bool NoDelay { get; set; } = true; diff --git a/Source/MQTTnet/Implementations/MqttTcpChannel.Uwp.cs b/Source/MQTTnet/Implementations/MqttTcpChannel.Uwp.cs index 1c0f944..cb0f71c 100644 --- a/Source/MQTTnet/Implementations/MqttTcpChannel.Uwp.cs +++ b/Source/MQTTnet/Implementations/MqttTcpChannel.Uwp.cs @@ -5,6 +5,7 @@ using System.IO; using System.Linq; using System.Runtime.InteropServices.WindowsRuntime; using System.Security.Authentication; +using System.Security.Cryptography.X509Certificates; using System.Threading; using System.Threading.Tasks; using Windows.Networking; @@ -31,7 +32,7 @@ namespace MQTTnet.Implementations _bufferSize = _options.BufferSize; } - public MqttTcpChannel(StreamSocket socket, IMqttServerOptions serverOptions) + public MqttTcpChannel(StreamSocket socket, X509Certificate2 clientCertificate, IMqttServerOptions serverOptions) { _socket = socket ?? throw new ArgumentNullException(nameof(socket)); _bufferSize = serverOptions.DefaultEndpointOptions.BufferSize; @@ -39,6 +40,7 @@ namespace MQTTnet.Implementations CreateStreams(); IsSecureConnection = socket.Information.ProtectionLevel >= SocketProtectionLevel.Tls12; + ClientCertificate = clientCertificate; Endpoint = _socket.Information.RemoteAddress + ":" + _socket.Information.RemotePort; } @@ -49,6 +51,8 @@ namespace MQTTnet.Implementations public bool IsSecureConnection { get; } + public X509Certificate2 ClientCertificate { get; } + public async Task ConnectAsync(CancellationToken cancellationToken) { if (_socket == null) diff --git a/Source/MQTTnet/Implementations/MqttTcpChannel.cs b/Source/MQTTnet/Implementations/MqttTcpChannel.cs index cd4e290..63adf55 100644 --- a/Source/MQTTnet/Implementations/MqttTcpChannel.cs +++ b/Source/MQTTnet/Implementations/MqttTcpChannel.cs @@ -28,18 +28,22 @@ namespace MQTTnet.Implementations IsSecureConnection = clientOptions.ChannelOptions?.TlsOptions?.UseTls == true; } - public MqttTcpChannel(Stream stream, string endpoint) + public MqttTcpChannel(Stream stream, string endpoint, X509Certificate2 clientCertificate) { _stream = stream ?? throw new ArgumentNullException(nameof(stream)); - IsSecureConnection = stream is SslStream; Endpoint = endpoint; + + IsSecureConnection = stream is SslStream; + ClientCertificate = clientCertificate; } public string Endpoint { get; private set; } public bool IsSecureConnection { get; } + public X509Certificate2 ClientCertificate { get; } + public async Task ConnectAsync(CancellationToken cancellationToken) { Socket socket; @@ -55,9 +59,16 @@ namespace MQTTnet.Implementations socket.ReceiveBufferSize = _options.BufferSize; socket.SendBufferSize = _options.BufferSize; - socket.DualMode = _options.DualMode; socket.NoDelay = _options.NoDelay; + if (_options.DualMode.HasValue) + { + // It is important to avoid setting the flag if no specific value is set by the user + // because on IPv4 only networks the setter will always throw an exception. Regardless + // of the actual value. + socket.DualMode = _options.DualMode.Value; + } + // Workaround for: workaround for https://github.com/dotnet/corefx/issues/24430 using (cancellationToken.Register(() => socket.Dispose())) { diff --git a/Source/MQTTnet/Implementations/MqttTcpServerAdapter.Uwp.cs b/Source/MQTTnet/Implementations/MqttTcpServerAdapter.Uwp.cs index 4529075..3b24bd1 100644 --- a/Source/MQTTnet/Implementations/MqttTcpServerAdapter.Uwp.cs +++ b/Source/MQTTnet/Implementations/MqttTcpServerAdapter.Uwp.cs @@ -1,11 +1,13 @@ #if WINDOWS_UWP -using System; -using System.Threading.Tasks; using Windows.Networking.Sockets; using MQTTnet.Adapter; using MQTTnet.Diagnostics; using MQTTnet.Formatter; using MQTTnet.Server; +using System; +using System.Runtime.InteropServices.WindowsRuntime; +using System.Security.Cryptography.X509Certificates; +using System.Threading.Tasks; namespace MQTTnet.Implementations { @@ -73,7 +75,21 @@ namespace MQTTnet.Implementations var clientHandler = ClientHandler; if (clientHandler != null) { - using (var clientAdapter = new MqttChannelAdapter(new MqttTcpChannel(args.Socket, _options), new MqttPacketFormatterAdapter(), _logger)) + X509Certificate2 clientCertificate = null; + + if (args.Socket.Control.ClientCertificate != null) + { + try + { + clientCertificate = new X509Certificate2(args.Socket.Control.ClientCertificate.GetCertificateBlob().ToArray()); + } + catch (Exception exception) + { + _logger.Warning(exception, "Unable to convert UWP certificate to X509Certificate2."); + } + } + + using (var clientAdapter = new MqttChannelAdapter(new MqttTcpChannel(args.Socket, clientCertificate, _options), new MqttPacketFormatterAdapter(), _logger)) { await clientHandler(clientAdapter).ConfigureAwait(false); } diff --git a/Source/MQTTnet/Implementations/MqttTcpServerListener.cs b/Source/MQTTnet/Implementations/MqttTcpServerListener.cs index cc8246a..fe03c73 100644 --- a/Source/MQTTnet/Implementations/MqttTcpServerListener.cs +++ b/Source/MQTTnet/Implementations/MqttTcpServerListener.cs @@ -112,17 +112,27 @@ namespace MQTTnet.Implementations stream = new NetworkStream(clientSocket, true); + X509Certificate2 clientCertificate = null; + if (_tlsCertificate != null) { var sslStream = new SslStream(stream, false); - await sslStream.AuthenticateAsServerAsync(_tlsCertificate, false, _tlsOptions.SslProtocol, false).ConfigureAwait(false); + + await sslStream.AuthenticateAsServerAsync( + _tlsCertificate, + _tlsOptions.ClientCertificateRequired, + _tlsOptions.SslProtocol, + _tlsOptions.CheckCertificateRevocation).ConfigureAwait(false); + stream = sslStream; + + clientCertificate = sslStream.RemoteCertificate as X509Certificate2; } var clientHandler = ClientHandler; if (clientHandler != null) { - using (var clientAdapter = new MqttChannelAdapter(new MqttTcpChannel(stream, remoteEndPoint), new MqttPacketFormatterAdapter(), _logger)) + using (var clientAdapter = new MqttChannelAdapter(new MqttTcpChannel(stream, remoteEndPoint, clientCertificate), new MqttPacketFormatterAdapter(), _logger)) { await clientHandler(clientAdapter).ConfigureAwait(false); } diff --git a/Source/MQTTnet/Implementations/MqttWebSocketChannel.cs b/Source/MQTTnet/Implementations/MqttWebSocketChannel.cs index ccc601a..38e4342 100644 --- a/Source/MQTTnet/Implementations/MqttWebSocketChannel.cs +++ b/Source/MQTTnet/Implementations/MqttWebSocketChannel.cs @@ -11,9 +11,9 @@ namespace MQTTnet.Implementations { public class MqttWebSocketChannel : IMqttChannel { - private readonly SemaphoreSlim _sendLock = new SemaphoreSlim(1, 1); private readonly MqttClientWebSocketOptions _options; + private SemaphoreSlim _sendLock = new SemaphoreSlim(1, 1); private WebSocket _webSocket; public MqttWebSocketChannel(MqttClientWebSocketOptions options) @@ -21,18 +21,21 @@ namespace MQTTnet.Implementations _options = options ?? throw new ArgumentNullException(nameof(options)); } - public MqttWebSocketChannel(WebSocket webSocket, string endpoint, bool isSecureConnection) + public MqttWebSocketChannel(WebSocket webSocket, string endpoint, bool isSecureConnection, X509Certificate2 clientCertificate) { _webSocket = webSocket ?? throw new ArgumentNullException(nameof(webSocket)); Endpoint = endpoint; IsSecureConnection = isSecureConnection; + ClientCertificate = clientCertificate; } public string Endpoint { get; } public bool IsSecureConnection { get; private set; } + public X509Certificate2 ClientCertificate { get; private set; } + public async Task ConnectAsync(CancellationToken cancellationToken) { var uri = _options.Uri; @@ -114,9 +117,14 @@ namespace MQTTnet.Implementations public async Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) { - // This lock is required because the client will throw an exception if _SendAsync_ is + // The lock is required because the client will throw an exception if _SendAsync_ is // called from multiple threads at the same time. But this issue only happens with several // framework versions. + if (_sendLock == null) + { + return; + } + await _sendLock.WaitAsync(cancellationToken).ConfigureAwait(false); try { @@ -124,13 +132,14 @@ namespace MQTTnet.Implementations } finally { - _sendLock.Release(); + _sendLock?.Release(); } } public void Dispose() { _sendLock?.Dispose(); + _sendLock = null; try { diff --git a/Source/MQTTnet/Internal/TestMqttChannel.cs b/Source/MQTTnet/Internal/TestMqttChannel.cs index 1ef8e5a..954aa1b 100644 --- a/Source/MQTTnet/Internal/TestMqttChannel.cs +++ b/Source/MQTTnet/Internal/TestMqttChannel.cs @@ -1,4 +1,5 @@ using System.IO; +using System.Security.Cryptography.X509Certificates; using System.Threading; using System.Threading.Tasks; using MQTTnet.Channel; @@ -18,6 +19,8 @@ namespace MQTTnet.Internal public bool IsSecureConnection { get; } = false; + public X509Certificate2 ClientCertificate { get; } + public Task ConnectAsync(CancellationToken cancellationToken) { return Task.FromResult(0); diff --git a/Source/MQTTnet/Server/MqttClientSessionsManager.cs b/Source/MQTTnet/Server/MqttClientSessionsManager.cs index fdc62b0..e889880 100644 --- a/Source/MQTTnet/Server/MqttClientSessionsManager.cs +++ b/Source/MQTTnet/Server/MqttClientSessionsManager.cs @@ -136,7 +136,7 @@ namespace MQTTnet.Server await connection.StopAsync().ConfigureAwait(false); } - if (_sessions.TryRemove(clientId, out var session)) + if (_sessions.TryRemove(clientId, out _)) { } @@ -296,7 +296,8 @@ namespace MQTTnet.Server connectPacket.Password, connectPacket.WillMessage, clientAdapter.Endpoint, - clientAdapter.IsSecureConnection); + clientAdapter.IsSecureConnection, + clientAdapter.ClientCertificate); var connectionValidator = _options.ConnectionValidator; diff --git a/Source/MQTTnet/Server/MqttConnectionValidatorContext.cs b/Source/MQTTnet/Server/MqttConnectionValidatorContext.cs index a00cffc..d42a65c 100644 --- a/Source/MQTTnet/Server/MqttConnectionValidatorContext.cs +++ b/Source/MQTTnet/Server/MqttConnectionValidatorContext.cs @@ -1,4 +1,5 @@ -using System.Text; +using System.Security.Cryptography.X509Certificates; +using System.Text; using MQTTnet.Protocol; namespace MQTTnet.Server @@ -11,7 +12,8 @@ namespace MQTTnet.Server byte[] password, MqttApplicationMessage willMessage, string endpoint, - bool isSecureConnection) + bool isSecureConnection, + X509Certificate2 clientCertificate) { ClientId = clientId; Username = username; @@ -19,6 +21,7 @@ namespace MQTTnet.Server WillMessage = willMessage; Endpoint = endpoint; IsSecureConnection = isSecureConnection; + ClientCertificate = clientCertificate; } public string ClientId { get; } @@ -35,6 +38,8 @@ namespace MQTTnet.Server public bool IsSecureConnection { get; } + public X509Certificate2 ClientCertificate { get; } + public MqttConnectReturnCode ReturnCode { get; set; } = MqttConnectReturnCode.ConnectionAccepted; } } diff --git a/Tests/MQTTnet.Benchmarks/SerializerBenchmark.cs b/Tests/MQTTnet.Benchmarks/SerializerBenchmark.cs index c8b3f7c..7fbb552 100644 --- a/Tests/MQTTnet.Benchmarks/SerializerBenchmark.cs +++ b/Tests/MQTTnet.Benchmarks/SerializerBenchmark.cs @@ -1,6 +1,7 @@ using BenchmarkDotNet.Attributes; using MQTTnet.Packets; using System; +using System.Security.Cryptography.X509Certificates; using System.Threading; using System.Threading.Tasks; using MQTTnet.Adapter; @@ -77,6 +78,8 @@ namespace MQTTnet.Benchmarks public bool IsSecureConnection { get; } = false; + public X509Certificate2 ClientCertificate { get; } + public void Reset() { _position = _buffer.Offset; diff --git a/Tests/MQTTnet.Core.Tests/Mockups/TestMqttCommunicationAdapter.cs b/Tests/MQTTnet.Core.Tests/Mockups/TestMqttCommunicationAdapter.cs index a58a77f..9aa99e6 100644 --- a/Tests/MQTTnet.Core.Tests/Mockups/TestMqttCommunicationAdapter.cs +++ b/Tests/MQTTnet.Core.Tests/Mockups/TestMqttCommunicationAdapter.cs @@ -1,5 +1,6 @@ using System; using System.Collections.Concurrent; +using System.Security.Cryptography.X509Certificates; using System.Threading; using System.Threading.Tasks; using MQTTnet.Adapter; @@ -18,6 +19,8 @@ namespace MQTTnet.Tests.Mockups public bool IsSecureConnection { get; } = false; + public X509Certificate2 ClientCertificate { get; } + public MqttPacketFormatterAdapter PacketFormatterAdapter { get; } = new MqttPacketFormatterAdapter(MqttProtocolVersion.V311); public long BytesSent { get; } diff --git a/Tests/MQTTnet.Core.Tests/MqttTcpChannel_Tests.cs b/Tests/MQTTnet.Core.Tests/MqttTcpChannel_Tests.cs index b2005ba..436d2d1 100644 --- a/Tests/MQTTnet.Core.Tests/MqttTcpChannel_Tests.cs +++ b/Tests/MQTTnet.Core.Tests/MqttTcpChannel_Tests.cs @@ -39,7 +39,7 @@ namespace MQTTnet.Tests await Task.Delay(100, ct.Token); - var tcpChannel = new MqttTcpChannel(new NetworkStream(clientSocket, true), "test"); + var tcpChannel = new MqttTcpChannel(new NetworkStream(clientSocket, true), "test", null); var buffer = new byte[1]; await tcpChannel.ReadAsync(buffer, 0, 1, ct.Token);