diff --git a/Source/MQTTnet/Implementations/MqttTcpChannel.cs b/Source/MQTTnet/Implementations/MqttTcpChannel.cs index 9b2ba56..71050e1 100644 --- a/Source/MQTTnet/Implementations/MqttTcpChannel.cs +++ b/Source/MQTTnet/Implementations/MqttTcpChannel.cs @@ -1,25 +1,24 @@ #if !WINDOWS_UWP +using MQTTnet.Channel; +using MQTTnet.Client.Options; using System; -using System.Net.Security; -using System.Net.Sockets; -using System.Security.Cryptography.X509Certificates; -using System.Threading.Tasks; using System.IO; using System.Linq; +using System.Net.Security; +using System.Net.Sockets; using System.Runtime.ExceptionServices; +using System.Security.Cryptography.X509Certificates; using System.Threading; -using MQTTnet.Channel; -using MQTTnet.Client.Options; -using MQTTnet.Internal; +using System.Threading.Tasks; namespace MQTTnet.Implementations { - public class MqttTcpChannel : Disposable, IMqttChannel + public sealed class MqttTcpChannel : IDisposable, IMqttChannel { - private readonly IMqttClientOptions _clientOptions; - private readonly MqttClientTcpOptions _options; + readonly IMqttClientOptions _clientOptions; + readonly MqttClientTcpOptions _options; - private Stream _stream; + Stream _stream; public MqttTcpChannel(IMqttClientOptions clientOptions) { @@ -69,7 +68,7 @@ namespace MQTTnet.Implementations // of the actual value. socket.DualMode = _options.DualMode.Value; } - + // Workaround for: workaround for https://github.com/dotnet/corefx/issues/24430 using (cancellationToken.Register(() => socket.Dispose())) { @@ -83,7 +82,7 @@ namespace MQTTnet.Implementations var sslStream = new SslStream(networkStream, false, InternalUserCertificateValidationCallback); _stream = sslStream; - await sslStream.AuthenticateAsClientAsync(_options.Server, LoadCertificates(), _options.TlsOptions.SslProtocol, !_options.TlsOptions.IgnoreCertificateRevocationErrors).ConfigureAwait(false); + await sslStream.AuthenticateAsClientAsync(_options.Server, LoadCertificates(), _options.TlsOptions.SslProtocol, !_options.TlsOptions.IgnoreCertificateRevocationErrors).ConfigureAwait(false); } else { @@ -95,12 +94,14 @@ namespace MQTTnet.Implementations public Task DisconnectAsync(CancellationToken cancellationToken) { - Cleanup(); + Dispose(); return Task.FromResult(0); } public async Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) { + if (buffer is null) throw new ArgumentNullException(nameof(buffer)); + try { // Workaround for: https://github.com/dotnet/corefx/issues/24430 @@ -131,6 +132,8 @@ namespace MQTTnet.Implementations public async Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) { + if (buffer is null) throw new ArgumentNullException(nameof(buffer)); + try { // Workaround for: https://github.com/dotnet/corefx/issues/24430 @@ -159,7 +162,7 @@ namespace MQTTnet.Implementations } } - private void Cleanup() + public void Dispose() { // When the stream is disposed it will also close the socket and this will also dispose it. // So there is no need to dispose the socket again. @@ -178,16 +181,7 @@ namespace MQTTnet.Implementations _stream = null; } - protected override void Dispose(bool disposing) - { - if (disposing) - { - Cleanup(); - } - base.Dispose(disposing); - } - - private bool InternalUserCertificateValidationCallback(object sender, X509Certificate x509Certificate, X509Chain chain, SslPolicyErrors sslPolicyErrors) + bool InternalUserCertificateValidationCallback(object sender, X509Certificate x509Certificate, X509Chain chain, SslPolicyErrors sslPolicyErrors) { if (_options.TlsOptions.CertificateValidationCallback != null) { @@ -218,7 +212,7 @@ namespace MQTTnet.Implementations return _options.TlsOptions.AllowUntrustedCertificates; } - private X509CertificateCollection LoadCertificates() + X509CertificateCollection LoadCertificates() { var certificates = new X509CertificateCollection(); if (_options.TlsOptions.Certificates == null) diff --git a/Source/MQTTnet/Implementations/PlatformAbstractionLayer.cs b/Source/MQTTnet/Implementations/PlatformAbstractionLayer.cs index ee9057a..80c0890 100644 --- a/Source/MQTTnet/Implementations/PlatformAbstractionLayer.cs +++ b/Source/MQTTnet/Implementations/PlatformAbstractionLayer.cs @@ -7,6 +7,7 @@ namespace MQTTnet.Implementations { public static class PlatformAbstractionLayer { + // TODO: Consider creating primitives like "MqttNetSocket" which will wrap all required methods and do the platform stuff. public static async Task AcceptAsync(Socket socket) { #if NET452 || NET461 @@ -90,7 +91,7 @@ namespace MQTTnet.Implementations public static Task CompletedTask { - get + get { #if NET452 return Task.FromResult(0); diff --git a/Source/MQTTnet/Internal/Disposable.cs b/Source/MQTTnet/Internal/Disposable.cs index 2ce3423..e9b05ea 100644 --- a/Source/MQTTnet/Internal/Disposable.cs +++ b/Source/MQTTnet/Internal/Disposable.cs @@ -2,32 +2,20 @@ namespace MQTTnet.Internal { - public class Disposable : IDisposable + public abstract class Disposable : IDisposable { - protected bool IsDisposed => _isDisposed; + protected bool IsDisposed { get; private set; } = false; protected void ThrowIfDisposed() { - if (_isDisposed) + if (IsDisposed) { throw new ObjectDisposedException(GetType().Name); } } - - #region IDisposable Support - - private bool _isDisposed = false; // To detect redundant calls - protected virtual void Dispose(bool disposing) { - if (disposing) - { - // TODO: dispose managed state (managed objects). - } - - // TODO: free unmanaged resources (unmanaged objects) and override a finalizer below. - // TODO: set large fields to null. } // TODO: override a finalizer only if Dispose(bool disposing) above has code to free unmanaged resources. @@ -40,18 +28,17 @@ namespace MQTTnet.Internal // This code added to correctly implement the disposable pattern. public void Dispose() { - if (_isDisposed) + // Do not change this code. Put cleanup code in Dispose(bool disposing) above. + + if (IsDisposed) { return; } - _isDisposed = true; + IsDisposed = true; - // Do not change this code. Put cleanup code in Dispose(bool disposing) above. Dispose(true); - // TODO: uncomment the following line if the finalizer is overridden above. - // GC.SuppressFinalize(this); + GC.SuppressFinalize(this); } - #endregion } } diff --git a/Source/MQTTnet/Server/MqttClientSessionsManager.cs b/Source/MQTTnet/Server/MqttClientSessionsManager.cs index 28c163d..6f8ea20 100644 --- a/Source/MQTTnet/Server/MqttClientSessionsManager.cs +++ b/Source/MQTTnet/Server/MqttClientSessionsManager.cs @@ -1,15 +1,16 @@ -using System; -using System.Collections.Concurrent; -using System.Collections.Generic; -using System.Threading; -using System.Threading.Tasks; -using MQTTnet.Adapter; +using MQTTnet.Adapter; using MQTTnet.Diagnostics; +using MQTTnet.Exceptions; using MQTTnet.Formatter; using MQTTnet.Internal; using MQTTnet.Packets; using MQTTnet.Protocol; using MQTTnet.Server.Status; +using System; +using System.Collections.Concurrent; +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; namespace MQTTnet.Server { @@ -236,12 +237,23 @@ namespace MQTTnet.Server string clientId = null; var clientWasConnected = true; + MqttConnectPacket connectPacket = null; + try { - var firstPacket = await channelAdapter.ReceivePacketAsync(_options.DefaultCommunicationTimeout, cancellationToken).ConfigureAwait(false); - if (!(firstPacket is MqttConnectPacket connectPacket)) + try + { + var firstPacket = await channelAdapter.ReceivePacketAsync(_options.DefaultCommunicationTimeout, cancellationToken).ConfigureAwait(false); + connectPacket = firstPacket as MqttConnectPacket; + if (connectPacket == null) + { + _logger.Warning(null, "The first packet from client '{0}' was no 'CONNECT' packet [MQTT-3.1.0-1].", channelAdapter.Endpoint); + return; + } + } + catch (MqttCommunicationTimedOutException) { - _logger.Warning(null, "The first packet from client '{0}' was no 'CONNECT' packet [MQTT-3.1.0-1].", channelAdapter.Endpoint); + _logger.Warning(null, "Client '{0}' connected but did not sent a CONNECT packet.", channelAdapter.Endpoint); return; } diff --git a/Tests/MQTTnet.Core.Tests/Server_Tests.cs b/Tests/MQTTnet.Core.Tests/Server_Tests.cs index 3ebdaa6..6923c69 100644 --- a/Tests/MQTTnet.Core.Tests/Server_Tests.cs +++ b/Tests/MQTTnet.Core.Tests/Server_Tests.cs @@ -1,11 +1,4 @@ -using System; -using System.Collections.Generic; -using System.Linq; -using System.Net.Sockets; -using System.Text; -using System.Threading; -using System.Threading.Tasks; -using Microsoft.VisualStudio.TestTools.UnitTesting; +using Microsoft.VisualStudio.TestTools.UnitTesting; using MQTTnet.Adapter; using MQTTnet.Client; using MQTTnet.Client.Connecting; @@ -17,6 +10,13 @@ using MQTTnet.Implementations; using MQTTnet.Protocol; using MQTTnet.Server; using MQTTnet.Tests.Mockups; +using System; +using System.Collections.Generic; +using System.Linq; +using System.Net.Sockets; +using System.Text; +using System.Threading; +using System.Threading.Tasks; namespace MQTTnet.Tests { @@ -54,7 +54,7 @@ namespace MQTTnet.Tests MqttQualityOfServiceLevel.AtMostOnce, "A/B/C", MqttQualityOfServiceLevel.AtMostOnce, - 1, + 1, TestContext); } @@ -1050,6 +1050,7 @@ namespace MQTTnet.Tests Assert.AreEqual("c", flow); // dc + // Connect client with same client ID. Should disconnect existing client. var c2 = await testEnvironment.ConnectClientAsync(clientOptions); c2.UseApplicationMessageReceivedHandler(_ => @@ -1058,8 +1059,8 @@ namespace MQTTnet.Tests { events.Add("r"); } - }); + c2.SubscribeAsync("topic").Wait(); await Task.Delay(500); @@ -1075,12 +1076,11 @@ namespace MQTTnet.Tests flow = string.Join(string.Empty, events); Assert.AreEqual("cdcr", flow); - // nothing Assert.AreEqual(false, c1.IsConnected); await c1.DisconnectAsync(); - Assert.AreEqual (false, c1.IsConnected); + Assert.AreEqual(false, c1.IsConnected); await Task.Delay(500); @@ -1141,7 +1141,7 @@ namespace MQTTnet.Tests await testEnvironment.ConnectClientAsync(); } } - + [TestMethod] public async Task Close_Idle_Connection() { @@ -1182,7 +1182,7 @@ namespace MQTTnet.Tests // forever. This is security related. var client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); await PlatformAbstractionLayer.ConnectAsync(client, "localhost", testEnvironment.ServerPort); - + var buffer = Encoding.UTF8.GetBytes("Garbage"); client.Send(buffer, buffer.Length, SocketFlags.None);