From 59f07c868dfb56b4c5f6ffa5cd1764274e186b92 Mon Sep 17 00:00:00 2001 From: Christian Kratky Date: Mon, 9 Oct 2017 19:03:11 +0200 Subject: [PATCH] Add custom certificate validation handler --- Build/MQTTnet.nuspec | 3 +- .../Implementations/MqttTcpChannel.cs | 11 ++++- .../Implementations/MqttTcpChannel.cs | 13 +++-- .../Implementations/MqttTcpChannel.cs | 48 +++++++++++++------ Tests/MQTTnet.Core.Tests/MqttServerTests.cs | 3 +- 5 files changed, 57 insertions(+), 21 deletions(-) diff --git a/Build/MQTTnet.nuspec b/Build/MQTTnet.nuspec index 5112cd5..982d205 100644 --- a/Build/MQTTnet.nuspec +++ b/Build/MQTTnet.nuspec @@ -17,9 +17,10 @@ * [Server] Fixed handling of _Dup_ flag (Thanks to haeberle) * [Core] Optimized exception handling * [Core] Mono is now also supported (Thanks to JTrotta) -* [Client] The options are now passed in _ConnectAsync_ (Breaking change!) +* [Client] The options are now passed in _ConnectAsync_ (Breaking change! Read Wiki for examples) * [Core] Trace class renamed to _MqttNetTrace_ (Breaking change!) * [Client] Extended certificate validation options (Breaking change!) +* [Client] Added static certificate validation callback (NetFramework, NetStandard) / ignorable certificate errors (UniversalWindows) to _MqttTcpChannel_ Copyright Christian Kratky 2016-2017 MQTT Message Queue Telemetry Transport MQTTClient MQTTServer Server MQTTBroker Broker NETStandard IoT InternetOfThings Messaging Hardware Arduino Sensor Actuator M2M diff --git a/Frameworks/MQTTnet.NetFramework/Implementations/MqttTcpChannel.cs b/Frameworks/MQTTnet.NetFramework/Implementations/MqttTcpChannel.cs index 7f32de8..6f9375e 100644 --- a/Frameworks/MQTTnet.NetFramework/Implementations/MqttTcpChannel.cs +++ b/Frameworks/MQTTnet.NetFramework/Implementations/MqttTcpChannel.cs @@ -46,6 +46,8 @@ namespace MQTTnet.Implementations public Stream ReceiveStream { get; private set; } public Stream RawReceiveStream { get; private set; } + public static Func CustomCertificateValidationCallback { get; set; } + public async Task ConnectAsync() { if (_socket == null) @@ -57,7 +59,7 @@ namespace MQTTnet.Implementations if (_options.TlsOptions.UseTls) { - _sslStream = new SslStream(new NetworkStream(_socket, true), false, UserCertificateValidationCallback); + _sslStream = new SslStream(new NetworkStream(_socket, true), false, InternalUserCertificateValidationCallback); await _sslStream.AuthenticateAsClientAsync(_options.Server, LoadCertificates(_options), SslProtocols.Tls12, _options.TlsOptions.IgnoreCertificateRevocationErrors).ConfigureAwait(false); } @@ -98,8 +100,13 @@ namespace MQTTnet.Implementations ReceiveStream = new BufferedStream(RawReceiveStream, BufferSize); } - private bool UserCertificateValidationCallback(object sender, X509Certificate x509Certificate, X509Chain chain, SslPolicyErrors sslPolicyErrors) + private bool InternalUserCertificateValidationCallback(object sender, X509Certificate x509Certificate, X509Chain chain, SslPolicyErrors sslPolicyErrors) { + if (CustomCertificateValidationCallback != null) + { + return CustomCertificateValidationCallback(x509Certificate, chain, sslPolicyErrors, _options); + } + if (sslPolicyErrors == SslPolicyErrors.None) { return true; diff --git a/Frameworks/MQTTnet.NetStandard/Implementations/MqttTcpChannel.cs b/Frameworks/MQTTnet.NetStandard/Implementations/MqttTcpChannel.cs index db08751..e60f641 100644 --- a/Frameworks/MQTTnet.NetStandard/Implementations/MqttTcpChannel.cs +++ b/Frameworks/MQTTnet.NetStandard/Implementations/MqttTcpChannel.cs @@ -40,7 +40,9 @@ namespace MQTTnet.Implementations public Stream SendStream => ReceiveStream; public Stream ReceiveStream { get; private set; } public Stream RawReceiveStream => ReceiveStream; - + + public static Func CustomCertificateValidationCallback { get; set; } + public async Task ConnectAsync() { if (_socket == null) @@ -52,7 +54,7 @@ namespace MQTTnet.Implementations if (_options.TlsOptions.UseTls) { - _sslStream = new SslStream(new NetworkStream(_socket, true), false, UserCertificateValidationCallback); + _sslStream = new SslStream(new NetworkStream(_socket, true), false, InternalUserCertificateValidationCallback); ReceiveStream = _sslStream; await _sslStream.AuthenticateAsClientAsync(_options.Server, LoadCertificates(_options), SslProtocols.Tls12, _options.TlsOptions.IgnoreCertificateRevocationErrors).ConfigureAwait(false); } @@ -77,8 +79,13 @@ namespace MQTTnet.Implementations _sslStream = null; } - private bool UserCertificateValidationCallback(object sender, X509Certificate x509Certificate, X509Chain chain, SslPolicyErrors sslPolicyErrors) + private bool InternalUserCertificateValidationCallback(object sender, X509Certificate x509Certificate, X509Chain chain, SslPolicyErrors sslPolicyErrors) { + if (CustomCertificateValidationCallback != null) + { + return CustomCertificateValidationCallback(x509Certificate, chain, sslPolicyErrors, _options); + } + if (sslPolicyErrors == SslPolicyErrors.None) { return true; diff --git a/Frameworks/MQTTnet.UniversalWindows/Implementations/MqttTcpChannel.cs b/Frameworks/MQTTnet.UniversalWindows/Implementations/MqttTcpChannel.cs index 228865d..591b739 100644 --- a/Frameworks/MQTTnet.UniversalWindows/Implementations/MqttTcpChannel.cs +++ b/Frameworks/MQTTnet.UniversalWindows/Implementations/MqttTcpChannel.cs @@ -1,4 +1,5 @@ using System; +using System.Collections.Generic; using System.IO; using System.Linq; using System.Runtime.InteropServices.WindowsRuntime; @@ -32,6 +33,8 @@ namespace MQTTnet.Implementations public Stream ReceiveStream { get; private set; } public Stream RawReceiveStream { get; private set; } + public Func> CustomIgnorableServerCertificateErrorsResolver { get; set; } + public async Task ConnectAsync() { if (_socket == null) @@ -47,23 +50,11 @@ namespace MQTTnet.Implementations { _socket.Control.ClientCertificate = LoadCertificate(_options); - if (_options.TlsOptions.IgnoreCertificateRevocationErrors) - { - _socket.Control.IgnorableServerCertificateErrors.Add(ChainValidationResult.RevocationInformationMissing); - //_socket.Control.IgnorableServerCertificateErrors.Add(ChainValidationResult.Revoked); Not supported. - _socket.Control.IgnorableServerCertificateErrors.Add(ChainValidationResult.RevocationFailure); - } - - if (_options.TlsOptions.IgnoreCertificateChainErrors) + foreach (var ignorableChainValidationResult in ResolveIgnorableServerCertificateErrors()) { - _socket.Control.IgnorableServerCertificateErrors.Add(ChainValidationResult.IncompleteChain); + _socket.Control.IgnorableServerCertificateErrors.Add(ignorableChainValidationResult); } - if (_options.TlsOptions.AllowUntrustedCertificates) - { - _socket.Control.IgnorableServerCertificateErrors.Add(ChainValidationResult.Untrusted); - } - await _socket.ConnectAsync(new HostName(_options.Server), _options.GetPort().ToString(), SocketProtectionLevel.Tls12); } @@ -112,5 +103,34 @@ namespace MQTTnet.Implementations return new Certificate(options.TlsOptions.Certificates.First().AsBuffer()); } + + private IEnumerable ResolveIgnorableServerCertificateErrors() + { + if (CustomIgnorableServerCertificateErrorsResolver != null) + { + return CustomIgnorableServerCertificateErrorsResolver(_options); + } + + var result = new List(); + + if (_options.TlsOptions.IgnoreCertificateRevocationErrors) + { + result.Add(ChainValidationResult.RevocationInformationMissing); + //_socket.Control.IgnorableServerCertificateErrors.Add(ChainValidationResult.Revoked); Not supported. + result.Add(ChainValidationResult.RevocationFailure); + } + + if (_options.TlsOptions.IgnoreCertificateChainErrors) + { + result.Add(ChainValidationResult.IncompleteChain); + } + + if (_options.TlsOptions.AllowUntrustedCertificates) + { + result.Add(ChainValidationResult.Untrusted); + } + + return result; + } } } \ No newline at end of file diff --git a/Tests/MQTTnet.Core.Tests/MqttServerTests.cs b/Tests/MQTTnet.Core.Tests/MqttServerTests.cs index 26a48e5..0e786df 100644 --- a/Tests/MQTTnet.Core.Tests/MqttServerTests.cs +++ b/Tests/MQTTnet.Core.Tests/MqttServerTests.cs @@ -183,6 +183,7 @@ namespace MQTTnet.Core.Tests var c1 = await serverAdapter.ConnectTestClient(s, "c1"); await c1.PublishAsync(new MqttApplicationMessage("retained", new byte[3], MqttQualityOfServiceLevel.AtLeastOnce, true)); + await c1.PublishAsync(new MqttApplicationMessage("retained", new byte[0], MqttQualityOfServiceLevel.AtLeastOnce, true)); await c1.DisconnectAsync(); var c2 = await serverAdapter.ConnectTestClient(s, "c2"); @@ -194,7 +195,7 @@ namespace MQTTnet.Core.Tests await s.StopAsync(); - Assert.AreEqual(1, receivedMessagesCount); + Assert.AreEqual(0, receivedMessagesCount); } [TestMethod]