From ad128c7889d039b81c3748547e73998e2535bbbc Mon Sep 17 00:00:00 2001 From: Christian Kratky Date: Fri, 3 Apr 2020 22:52:55 +0200 Subject: [PATCH] Improve socket handling. --- .../Implementations/CrossPlatformSocket.cs | 229 ++++++++++++++++++ .../MQTTnet/Implementations/MqttTcpChannel.cs | 38 ++- .../Implementations/MqttTcpServerListener.cs | 14 +- .../PlatformAbstractionLayer.cs | 87 +------ Source/MQTTnet/Server/MqttClientConnection.cs | 28 +-- .../Server/MqttClientSessionsManager.cs | 11 +- .../Server/MqttClientSubscriptionsManager.cs | 18 +- .../Server/MqttServerEventDispatcher.cs | 82 ++++--- .../CrossPlatformSocket_Tests.cs | 74 ++++++ .../MqttTcpChannel_Tests.cs | 18 +- Tests/MQTTnet.Core.Tests/Server_Tests.cs | 33 +-- 11 files changed, 438 insertions(+), 194 deletions(-) create mode 100644 Source/MQTTnet/Implementations/CrossPlatformSocket.cs create mode 100644 Tests/MQTTnet.Core.Tests/CrossPlatformSocket_Tests.cs diff --git a/Source/MQTTnet/Implementations/CrossPlatformSocket.cs b/Source/MQTTnet/Implementations/CrossPlatformSocket.cs new file mode 100644 index 0000000..4d7f6c9 --- /dev/null +++ b/Source/MQTTnet/Implementations/CrossPlatformSocket.cs @@ -0,0 +1,229 @@ +using System; +using System.Net; +using System.Net.Sockets; +using System.Threading; +using System.Threading.Tasks; + +namespace MQTTnet.Implementations +{ + public sealed class CrossPlatformSocket : IDisposable + { + readonly Socket _socket; + + public CrossPlatformSocket(AddressFamily addressFamily) + { + _socket = new Socket(addressFamily, SocketType.Stream, ProtocolType.Tcp); + } + + public CrossPlatformSocket() + { + // Having this contructor is important because avoiding the address family as parameter + // will make use of dual mode in the .net framework. + _socket = new Socket(SocketType.Stream, ProtocolType.Tcp); + } + + public CrossPlatformSocket(Socket socket) + { + _socket = socket ?? throw new ArgumentNullException(nameof(socket)); + } + + public bool NoDelay + { + get + { + return (int)_socket.GetSocketOption(SocketOptionLevel.Tcp, SocketOptionName.NoDelay) > 0; + } + + set + { + _socket.SetSocketOption(SocketOptionLevel.Tcp, SocketOptionName.NoDelay, value ? 1 : 0); + } + } + + public bool DualMode + { + get + { + return _socket.DualMode; + } + + set + { + _socket.DualMode = value; + } + } + + public int ReceiveBufferSize + { + get + { + return _socket.ReceiveBufferSize; + } + + set + { + _socket.ReceiveBufferSize = value; + } + } + + public int SendBufferSize + { + get + { + return _socket.SendBufferSize; + } + + set + { + _socket.SendBufferSize = value; + } + } + + public EndPoint RemoteEndPoint + { + get + { + return _socket.RemoteEndPoint; + } + } + + public bool ReuseAddress + { + get + { + return (int)_socket.GetSocketOption(SocketOptionLevel.Socket, SocketOptionName.ReuseAddress) != 0; + } + + set + { + _socket.SetSocketOption(SocketOptionLevel.Socket, SocketOptionName.ReuseAddress, value ? 1 : 0); + } + } + + public async Task AcceptAsync() + { + try + { +#if NET452 || NET461 + var clientSocket = await Task.Factory.FromAsync(_socket.BeginAccept, _socket.EndAccept, null).ConfigureAwait(false); + return new CrossPlatformSocket(clientSocket); +#else + var clientSocket = await _socket.AcceptAsync().ConfigureAwait(false); + return new CrossPlatformSocket(clientSocket); +#endif + } + catch (ObjectDisposedException) + { + // This will happen when _socket.EndAccept gets called by Task library but the socket is already disposed. + return null; + } + } + + public void Bind(EndPoint localEndPoint) + { + if (localEndPoint is null) throw new ArgumentNullException(nameof(localEndPoint)); + + _socket.Bind(localEndPoint); + } + + public void Listen(int connectionBacklog) + { + _socket.Listen(connectionBacklog); + } + + public async Task ConnectAsync(string host, int port, CancellationToken cancellationToken) + { + if (host is null) throw new ArgumentNullException(nameof(host)); + + try + { + // Workaround for: workaround for https://github.com/dotnet/corefx/issues/24430 + using (cancellationToken.Register(() => _socket.Dispose())) + { + cancellationToken.ThrowIfCancellationRequested(); + +#if NET452 || NET461 + await Task.Factory.FromAsync(_socket.BeginConnect, _socket.EndConnect, host, port, null).ConfigureAwait(false); +#else + await _socket.ConnectAsync(host, port).ConfigureAwait(false); +#endif + } + } + catch (ObjectDisposedException) + { + // This will happen when _socket.EndConnect gets called by Task library but the socket is already disposed. + } + } + + public async Task SendAsync(ArraySegment buffer, SocketFlags socketFlags) + { + try + { +#if NET452 || NET461 + await Task.Factory.FromAsync(SocketWrapper.BeginSend, _socket.EndSend, new SocketWrapper(_socket, buffer, socketFlags)).ConfigureAwait(false); +#else + await _socket.SendAsync(buffer, socketFlags).ConfigureAwait(false); +#endif + } + catch (ObjectDisposedException) + { + // This will happen when _socket.EndConnect gets called by Task library but the socket is already disposed. + } + } + + public async Task ReceiveAsync(ArraySegment buffer, SocketFlags socketFlags) + { + try + { +#if NET452 || NET461 + return await Task.Factory.FromAsync(SocketWrapper.BeginReceive, _socket.EndReceive, new SocketWrapper(_socket, buffer, socketFlags)).ConfigureAwait(false); +#else + return await _socket.ReceiveAsync(buffer, socketFlags).ConfigureAwait(false); +#endif + } + catch (ObjectDisposedException) + { + // This will happen when _socket.EndReceive gets called by Task library but the socket is already disposed. + return -1; + } + } + + public NetworkStream GetStream() + { + return new NetworkStream(_socket, true); + } + + public void Dispose() + { + _socket?.Dispose(); + } + +#if NET452 || NET461 + class SocketWrapper + { + readonly Socket _socket; + readonly ArraySegment _buffer; + readonly SocketFlags _socketFlags; + + public SocketWrapper(Socket socket, ArraySegment buffer, SocketFlags socketFlags) + { + _socket = socket; + _buffer = buffer; + _socketFlags = socketFlags; + } + + public static IAsyncResult BeginSend(AsyncCallback callback, object state) + { + var socketWrapper = (SocketWrapper)state; + return socketWrapper._socket.BeginSend(socketWrapper._buffer.Array, socketWrapper._buffer.Offset, socketWrapper._buffer.Count, socketWrapper._socketFlags, callback, state); + } + + public static IAsyncResult BeginReceive(AsyncCallback callback, object state) + { + var socketWrapper = (SocketWrapper)state; + return socketWrapper._socket.BeginReceive(socketWrapper._buffer.Array, socketWrapper._buffer.Offset, socketWrapper._buffer.Count, socketWrapper._socketFlags, callback, state); + } + } +#endif + } +} diff --git a/Source/MQTTnet/Implementations/MqttTcpChannel.cs b/Source/MQTTnet/Implementations/MqttTcpChannel.cs index 71050e1..021a7d4 100644 --- a/Source/MQTTnet/Implementations/MqttTcpChannel.cs +++ b/Source/MQTTnet/Implementations/MqttTcpChannel.cs @@ -46,15 +46,15 @@ namespace MQTTnet.Implementations public async Task ConnectAsync(CancellationToken cancellationToken) { - Socket socket; + CrossPlatformSocket socket; if (_options.AddressFamily == AddressFamily.Unspecified) { - socket = new Socket(SocketType.Stream, ProtocolType.Tcp); + socket = new CrossPlatformSocket(); } else { - socket = new Socket(_options.AddressFamily, SocketType.Stream, ProtocolType.Tcp); + socket = new CrossPlatformSocket(_options.AddressFamily); } socket.ReceiveBufferSize = _options.BufferSize; @@ -69,20 +69,24 @@ namespace MQTTnet.Implementations socket.DualMode = _options.DualMode.Value; } - // Workaround for: workaround for https://github.com/dotnet/corefx/issues/24430 - using (cancellationToken.Register(() => socket.Dispose())) - { - await PlatformAbstractionLayer.ConnectAsync(socket, _options.Server, _options.GetPort()).ConfigureAwait(false); - } + await socket.ConnectAsync(_options.Server, _options.GetPort(), cancellationToken).ConfigureAwait(false); - var networkStream = new NetworkStream(socket, true); + var networkStream = socket.GetStream(); if (_options.TlsOptions.UseTls) { var sslStream = new SslStream(networkStream, false, InternalUserCertificateValidationCallback); - _stream = sslStream; + try + { + await sslStream.AuthenticateAsClientAsync(_options.Server, LoadCertificates(), _options.TlsOptions.SslProtocol, !_options.TlsOptions.IgnoreCertificateRevocationErrors).ConfigureAwait(false); + } + catch + { + sslStream.Dispose(); + throw; + } - await sslStream.AuthenticateAsClientAsync(_options.Server, LoadCertificates(), _options.TlsOptions.SslProtocol, !_options.TlsOptions.IgnoreCertificateRevocationErrors).ConfigureAwait(false); + _stream = sslStream; } else { @@ -107,17 +111,14 @@ namespace MQTTnet.Implementations // Workaround for: https://github.com/dotnet/corefx/issues/24430 using (cancellationToken.Register(Dispose)) { - if (cancellationToken.IsCancellationRequested) - { - return 0; - } + cancellationToken.ThrowIfCancellationRequested(); return await _stream.ReadAsync(buffer, offset, count, cancellationToken).ConfigureAwait(false); } } catch (ObjectDisposedException) { - return 0; + return -1; } catch (IOException exception) { @@ -139,10 +140,7 @@ namespace MQTTnet.Implementations // Workaround for: https://github.com/dotnet/corefx/issues/24430 using (cancellationToken.Register(Dispose)) { - if (cancellationToken.IsCancellationRequested) - { - return; - } + cancellationToken.ThrowIfCancellationRequested(); await _stream.WriteAsync(buffer, offset, count, cancellationToken).ConfigureAwait(false); } diff --git a/Source/MQTTnet/Implementations/MqttTcpServerListener.cs b/Source/MQTTnet/Implementations/MqttTcpServerListener.cs index 84cd7b9..83ef80f 100644 --- a/Source/MQTTnet/Implementations/MqttTcpServerListener.cs +++ b/Source/MQTTnet/Implementations/MqttTcpServerListener.cs @@ -23,7 +23,7 @@ namespace MQTTnet.Implementations readonly MqttServerTlsTcpEndpointOptions _tlsOptions; readonly X509Certificate2 _tlsCertificate; - private Socket _socket; + private CrossPlatformSocket _socket; private IPEndPoint _localEndPoint; public MqttTcpServerListener( @@ -59,18 +59,18 @@ namespace MQTTnet.Implementations _logger.Info($"Starting TCP listener for {_localEndPoint} TLS={_tlsCertificate != null}."); - _socket = new Socket(_addressFamily, SocketType.Stream, ProtocolType.Tcp); + _socket = new CrossPlatformSocket(_addressFamily); // Usage of socket options is described here: https://docs.microsoft.com/en-us/dotnet/api/system.net.sockets.socket.setsocketoption?view=netcore-2.2 if (_options.ReuseAddress) { - _socket.SetSocketOption(SocketOptionLevel.Socket, SocketOptionName.ReuseAddress, true); + _socket.ReuseAddress = true; } if (_options.NoDelay) { - _socket.SetSocketOption(SocketOptionLevel.Tcp, SocketOptionName.NoDelay, true); + _socket.NoDelay = true; } _socket.Bind(_localEndPoint); @@ -107,7 +107,7 @@ namespace MQTTnet.Implementations { try { - var clientSocket = await PlatformAbstractionLayer.AcceptAsync(_socket).ConfigureAwait(false); + var clientSocket = await _socket.AcceptAsync().ConfigureAwait(false); if (clientSocket == null) { continue; @@ -135,7 +135,7 @@ namespace MQTTnet.Implementations } } - async Task TryHandleClientConnectionAsync(Socket clientSocket) + async Task TryHandleClientConnectionAsync(CrossPlatformSocket clientSocket) { Stream stream = null; string remoteEndPoint = null; @@ -151,7 +151,7 @@ namespace MQTTnet.Implementations clientSocket.NoDelay = _options.NoDelay; - stream = new NetworkStream(clientSocket, true); + stream = clientSocket.GetStream(); X509Certificate2 clientCertificate = null; diff --git a/Source/MQTTnet/Implementations/PlatformAbstractionLayer.cs b/Source/MQTTnet/Implementations/PlatformAbstractionLayer.cs index 80c0890..0b683dc 100644 --- a/Source/MQTTnet/Implementations/PlatformAbstractionLayer.cs +++ b/Source/MQTTnet/Implementations/PlatformAbstractionLayer.cs @@ -1,94 +1,9 @@ -using System; -using System.Net; -using System.Net.Sockets; -using System.Threading.Tasks; +using System.Threading.Tasks; namespace MQTTnet.Implementations { public static class PlatformAbstractionLayer { - // TODO: Consider creating primitives like "MqttNetSocket" which will wrap all required methods and do the platform stuff. - public static async Task AcceptAsync(Socket socket) - { -#if NET452 || NET461 - try - { - return await Task.Factory.FromAsync(socket.BeginAccept, socket.EndAccept, null).ConfigureAwait(false); - } - catch (ObjectDisposedException) - { - return null; - } -#else - return await socket.AcceptAsync().ConfigureAwait(false); -#endif - } - - - public static Task ConnectAsync(Socket socket, IPAddress ip, int port) - { -#if NET452 || NET461 - return Task.Factory.FromAsync(socket.BeginConnect, socket.EndConnect, ip, port, null); -#else - return socket.ConnectAsync(ip, port); -#endif - } - - public static Task ConnectAsync(Socket socket, string host, int port) - { -#if NET452 || NET461 - return Task.Factory.FromAsync(socket.BeginConnect, socket.EndConnect, host, port, null); -#else - return socket.ConnectAsync(host, port); -#endif - } - -#if NET452 || NET461 - public class SocketWrapper - { - private readonly Socket _socket; - private readonly ArraySegment _buffer; - private readonly SocketFlags _socketFlags; - - public SocketWrapper(Socket socket, ArraySegment buffer, SocketFlags socketFlags) - { - _socket = socket; - _buffer = buffer; - _socketFlags = socketFlags; - } - - public static IAsyncResult BeginSend(AsyncCallback callback, object state) - { - var real = (SocketWrapper)state; - return real._socket.BeginSend(real._buffer.Array, real._buffer.Offset, real._buffer.Count, real._socketFlags, callback, state); - } - - public static IAsyncResult BeginReceive(AsyncCallback callback, object state) - { - var real = (SocketWrapper)state; - return real._socket.BeginReceive(real._buffer.Array, real._buffer.Offset, real._buffer.Count, real._socketFlags, callback, state); - } - } -#endif - - public static Task SendAsync(Socket socket, ArraySegment buffer, SocketFlags socketFlags) - { -#if NET452 || NET461 - return Task.Factory.FromAsync(SocketWrapper.BeginSend, socket.EndSend, new SocketWrapper(socket, buffer, socketFlags)); -#else - return socket.SendAsync(buffer, socketFlags); -#endif - } - - public static Task ReceiveAsync(Socket socket, ArraySegment buffer, SocketFlags socketFlags) - { -#if NET452 || NET461 - return Task.Factory.FromAsync(SocketWrapper.BeginReceive, socket.EndReceive, new SocketWrapper(socket, buffer, socketFlags)); -#else - return socket.ReceiveAsync(buffer, socketFlags); -#endif - } - public static Task CompletedTask { get diff --git a/Source/MQTTnet/Server/MqttClientConnection.cs b/Source/MQTTnet/Server/MqttClientConnection.cs index 54475d0..70beba8 100644 --- a/Source/MQTTnet/Server/MqttClientConnection.cs +++ b/Source/MQTTnet/Server/MqttClientConnection.cs @@ -15,7 +15,7 @@ using System.Threading.Tasks; namespace MQTTnet.Server { - public class MqttClientConnection : IDisposable + public sealed class MqttClientConnection : IDisposable { private readonly MqttPacketIdentifierProvider _packetIdentifierProvider = new MqttPacketIdentifierProvider(); private readonly MqttPacketDispatcher _packetDispatcher = new MqttPacketDispatcher(); @@ -124,7 +124,7 @@ namespace MQTTnet.Server return _packageReceiverTask; } - private async Task RunInternalAsync(MqttConnectionValidatorContext connectionValidatorContext) + async Task RunInternalAsync(MqttConnectionValidatorContext connectionValidatorContext) { var disconnectType = MqttClientDisconnectType.NotClean; try @@ -251,12 +251,12 @@ namespace MQTTnet.Server return disconnectType; } - private void StopInternal() + void StopInternal() { _cancellationToken.Cancel(false); } - private async Task EnqueueSubscribedRetainedMessagesAsync(ICollection topicFilters) + async Task EnqueueSubscribedRetainedMessagesAsync(ICollection topicFilters) { var retainedMessages = await _retainedMessagesManager.GetSubscribedMessagesAsync(topicFilters).ConfigureAwait(false); foreach (var applicationMessage in retainedMessages) @@ -265,7 +265,7 @@ namespace MQTTnet.Server } } - private async Task HandleIncomingSubscribePacketAsync(MqttSubscribePacket subscribePacket) + async Task HandleIncomingSubscribePacketAsync(MqttSubscribePacket subscribePacket) { // TODO: Let the channel adapter create the packet. var subscribeResult = await Session.SubscriptionsManager.SubscribeAsync(subscribePacket, ConnectPacket).ConfigureAwait(false); @@ -281,14 +281,14 @@ namespace MQTTnet.Server await EnqueueSubscribedRetainedMessagesAsync(subscribePacket.TopicFilters).ConfigureAwait(false); } - private async Task HandleIncomingUnsubscribePacketAsync(MqttUnsubscribePacket unsubscribePacket) + async Task HandleIncomingUnsubscribePacketAsync(MqttUnsubscribePacket unsubscribePacket) { // TODO: Let the channel adapter create the packet. var unsubscribeResult = await Session.SubscriptionsManager.UnsubscribeAsync(unsubscribePacket).ConfigureAwait(false); await SendAsync(unsubscribeResult).ConfigureAwait(false); } - private Task HandleIncomingPublishPacketAsync(MqttPublishPacket publishPacket) + Task HandleIncomingPublishPacketAsync(MqttPublishPacket publishPacket) { Interlocked.Increment(ref _sentApplicationMessagesCount); @@ -313,7 +313,7 @@ namespace MQTTnet.Server } } - private Task HandleIncomingPublishPacketWithQoS0Async(MqttPublishPacket publishPacket) + Task HandleIncomingPublishPacketWithQoS0Async(MqttPublishPacket publishPacket) { var applicationMessage = _dataConverter.CreateApplicationMessage(publishPacket); @@ -322,7 +322,7 @@ namespace MQTTnet.Server return Task.FromResult(0); } - private Task HandleIncomingPublishPacketWithQoS1Async(MqttPublishPacket publishPacket) + Task HandleIncomingPublishPacketWithQoS1Async(MqttPublishPacket publishPacket) { var applicationMessage = _dataConverter.CreateApplicationMessage(publishPacket); _sessionsManager.DispatchApplicationMessage(applicationMessage, this); @@ -331,7 +331,7 @@ namespace MQTTnet.Server return SendAsync(pubAckPacket); } - private Task HandleIncomingPublishPacketWithQoS2Async(MqttPublishPacket publishPacket) + Task HandleIncomingPublishPacketWithQoS2Async(MqttPublishPacket publishPacket) { var applicationMessage = _dataConverter.CreateApplicationMessage(publishPacket); _sessionsManager.DispatchApplicationMessage(applicationMessage, this); @@ -345,7 +345,7 @@ namespace MQTTnet.Server return SendAsync(pubRecPacket); } - private async Task SendPendingPacketsAsync(CancellationToken cancellationToken) + async Task SendPendingPacketsAsync(CancellationToken cancellationToken) { MqttQueuedApplicationMessage queuedApplicationMessage = null; MqttPublishPacket publishPacket = null; @@ -459,7 +459,7 @@ namespace MQTTnet.Server } } - private async Task SendAsync(MqttBasePacket packet) + async Task SendAsync(MqttBasePacket packet) { await _channelAdapter.SendPacketAsync(packet, _serverOptions.DefaultCommunicationTimeout, _cancellationToken.Token).ConfigureAwait(false); @@ -471,12 +471,12 @@ namespace MQTTnet.Server } } - private void OnAdapterReadingPacketCompleted() + void OnAdapterReadingPacketCompleted() { _keepAliveMonitor?.Resume(); } - private void OnAdapterReadingPacketStarted() + void OnAdapterReadingPacketStarted() { _keepAliveMonitor?.Pause(); } diff --git a/Source/MQTTnet/Server/MqttClientSessionsManager.cs b/Source/MQTTnet/Server/MqttClientSessionsManager.cs index 285f63e..08568d8 100644 --- a/Source/MQTTnet/Server/MqttClientSessionsManager.cs +++ b/Source/MQTTnet/Server/MqttClientSessionsManager.cs @@ -207,7 +207,7 @@ namespace MQTTnet.Server applicationMessage = interceptorContext.ApplicationMessage; } - await _eventDispatcher.HandleApplicationMessageReceivedAsync(sender?.ClientId, applicationMessage).ConfigureAwait(false); + await _eventDispatcher.SafeNotifyApplicationMessageReceivedAsync(sender?.ClientId, applicationMessage).ConfigureAwait(false); if (applicationMessage.Retain) { @@ -237,7 +237,7 @@ namespace MQTTnet.Server string clientId = null; var clientWasConnected = true; - MqttConnectPacket connectPacket = null; + MqttConnectPacket connectPacket; try { @@ -259,8 +259,6 @@ namespace MQTTnet.Server var connectionValidatorContext = await ValidateConnectionAsync(connectPacket, channelAdapter).ConfigureAwait(false); - clientId = connectPacket.ClientId; - if (connectionValidatorContext.ReasonCode != MqttConnectReasonCode.Success) { clientWasConnected = false; @@ -272,9 +270,10 @@ namespace MQTTnet.Server return; } + clientId = connectPacket.ClientId; var connection = await CreateConnectionAsync(connectPacket, connectionValidatorContext, channelAdapter).ConfigureAwait(false); - await _eventDispatcher.HandleClientConnectedAsync(clientId).ConfigureAwait(false); + await _eventDispatcher.SafeNotifyClientConnectedAsync(clientId).ConfigureAwait(false); disconnectType = await connection.RunAsync(connectionValidatorContext).ConfigureAwait(false); } @@ -303,7 +302,7 @@ namespace MQTTnet.Server if (clientId != null) { - await _eventDispatcher.TryHandleClientDisconnectedAsync(clientId, disconnectType).ConfigureAwait(false); + await _eventDispatcher.SafeNotifyClientDisconnectedAsync(clientId, disconnectType).ConfigureAwait(false); } } } diff --git a/Source/MQTTnet/Server/MqttClientSubscriptionsManager.cs b/Source/MQTTnet/Server/MqttClientSubscriptionsManager.cs index deeadf4..59e9f34 100644 --- a/Source/MQTTnet/Server/MqttClientSubscriptionsManager.cs +++ b/Source/MQTTnet/Server/MqttClientSubscriptionsManager.cs @@ -1,9 +1,9 @@ -using System; +using MQTTnet.Packets; +using MQTTnet.Protocol; +using System; using System.Collections.Generic; using System.Linq; using System.Threading.Tasks; -using MQTTnet.Packets; -using MQTTnet.Protocol; namespace MQTTnet.Server { @@ -67,7 +67,7 @@ namespace MQTTnet.Server _subscriptions[finalTopicFilter.Topic] = finalTopicFilter; } - await _eventDispatcher.HandleClientSubscribedTopicAsync(_clientSession.ClientId, finalTopicFilter).ConfigureAwait(false); + await _eventDispatcher.SafeNotifyClientSubscribedTopicAsync(_clientSession.ClientId, finalTopicFilter).ConfigureAwait(false); } } @@ -83,7 +83,7 @@ namespace MQTTnet.Server var interceptorContext = await InterceptSubscribeAsync(topicFilter).ConfigureAwait(false); if (!interceptorContext.AcceptSubscription) { - continue; + continue; } if (interceptorContext.AcceptSubscription) @@ -93,7 +93,7 @@ namespace MQTTnet.Server _subscriptions[topicFilter.Topic] = topicFilter; } - await _eventDispatcher.HandleClientSubscribedTopicAsync(_clientSession.ClientId, topicFilter).ConfigureAwait(false); + await _eventDispatcher.SafeNotifyClientSubscribedTopicAsync(_clientSession.ClientId, topicFilter).ConfigureAwait(false); } } } @@ -131,9 +131,9 @@ namespace MQTTnet.Server foreach (var topicFilter in unsubscribePacket.TopicFilters) { - await _eventDispatcher.HandleClientUnsubscribedTopicAsync(_clientSession.ClientId, topicFilter).ConfigureAwait(false); + await _eventDispatcher.SafeNotifyClientUnsubscribedTopicAsync(_clientSession.ClientId, topicFilter).ConfigureAwait(false); } - + return unsubAckPacket; } @@ -152,7 +152,7 @@ namespace MQTTnet.Server lock (_subscriptions) { _subscriptions.Remove(topicFilter); - } + } } } diff --git a/Source/MQTTnet/Server/MqttServerEventDispatcher.cs b/Source/MQTTnet/Server/MqttServerEventDispatcher.cs index 3eb6b85..3bb3768 100644 --- a/Source/MQTTnet/Server/MqttServerEventDispatcher.cs +++ b/Source/MQTTnet/Server/MqttServerEventDispatcher.cs @@ -7,7 +7,7 @@ namespace MQTTnet.Server { public class MqttServerEventDispatcher { - private readonly IMqttNetLogger _logger; + readonly IMqttNetLogger _logger; public MqttServerEventDispatcher(IMqttNetLogger logger) { @@ -24,18 +24,25 @@ namespace MQTTnet.Server public IMqttApplicationMessageReceivedHandler ApplicationMessageReceivedHandler { get; set; } - public Task HandleClientConnectedAsync(string clientId) + public async Task SafeNotifyClientConnectedAsync(string clientId) { - var handler = ClientConnectedHandler; - if (handler == null) + try { - return Task.FromResult(0); - } + var handler = ClientConnectedHandler; + if (handler == null) + { + return; + } - return handler.HandleClientConnectedAsync(new MqttServerClientConnectedEventArgs(clientId)); + await handler.HandleClientConnectedAsync(new MqttServerClientConnectedEventArgs(clientId)).ConfigureAwait(false); + } + catch (Exception exception) + { + _logger.Error(exception, "Error while handling custom 'ClientConnected' event."); + } } - public async Task TryHandleClientDisconnectedAsync(string clientId, MqttClientDisconnectType disconnectType) + public async Task SafeNotifyClientDisconnectedAsync(string clientId, MqttClientDisconnectType disconnectType) { try { @@ -49,41 +56,62 @@ namespace MQTTnet.Server } catch (Exception exception) { - _logger.Error(exception, "Error while handling 'ClientDisconnected' event."); + _logger.Error(exception, "Error while handling custom 'ClientDisconnected' event."); } } - public Task HandleClientSubscribedTopicAsync(string clientId, TopicFilter topicFilter) + public async Task SafeNotifyClientSubscribedTopicAsync(string clientId, TopicFilter topicFilter) { - var handler = ClientSubscribedTopicHandler; - if (handler == null) + try { - return Task.FromResult(0); - } + var handler = ClientSubscribedTopicHandler; + if (handler == null) + { + return; + } - return handler.HandleClientSubscribedTopicAsync(new MqttServerClientSubscribedTopicEventArgs(clientId, topicFilter)); + await handler.HandleClientSubscribedTopicAsync(new MqttServerClientSubscribedTopicEventArgs(clientId, topicFilter)).ConfigureAwait(false); + } + catch (Exception exception) + { + _logger.Error(exception, "Error while handling custom 'ClientSubscribedTopic' event."); + } } - public Task HandleClientUnsubscribedTopicAsync(string clientId, string topicFilter) + public async Task SafeNotifyClientUnsubscribedTopicAsync(string clientId, string topicFilter) { - var handler = ClientUnsubscribedTopicHandler; - if (handler == null) + try { - return Task.FromResult(0); - } + var handler = ClientUnsubscribedTopicHandler; + if (handler == null) + { + return; + } - return handler.HandleClientUnsubscribedTopicAsync(new MqttServerClientUnsubscribedTopicEventArgs(clientId, topicFilter)); + await handler.HandleClientUnsubscribedTopicAsync(new MqttServerClientUnsubscribedTopicEventArgs(clientId, topicFilter)).ConfigureAwait(false); + } + catch (Exception exception) + { + _logger.Error(exception, "Error while handling custom 'ClientUnsubscribedTopic' event."); + } } - public Task HandleApplicationMessageReceivedAsync(string senderClientId, MqttApplicationMessage applicationMessage) + public async Task SafeNotifyApplicationMessageReceivedAsync(string senderClientId, MqttApplicationMessage applicationMessage) { - var handler = ApplicationMessageReceivedHandler; - if (handler == null) + try { - return Task.FromResult(0); - } + var handler = ApplicationMessageReceivedHandler; + if (handler == null) + { + return; + } - return handler.HandleApplicationMessageReceivedAsync(new MqttApplicationMessageReceivedEventArgs(senderClientId, applicationMessage)); + await handler.HandleApplicationMessageReceivedAsync(new MqttApplicationMessageReceivedEventArgs(senderClientId, applicationMessage)).ConfigureAwait(false); ; + } + catch (Exception exception) + { + _logger.Error(exception, "Error while handling custom 'ApplicationMessageReceived' event."); + } } } } diff --git a/Tests/MQTTnet.Core.Tests/CrossPlatformSocket_Tests.cs b/Tests/MQTTnet.Core.Tests/CrossPlatformSocket_Tests.cs new file mode 100644 index 0000000..3e3455e --- /dev/null +++ b/Tests/MQTTnet.Core.Tests/CrossPlatformSocket_Tests.cs @@ -0,0 +1,74 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using MQTTnet.Implementations; +using System; +using System.Text; +using System.Threading; +using System.Threading.Tasks; + +namespace MQTTnet.Tests +{ + [TestClass] + public class CrossPlatformSocket_Tests + { + [TestMethod] + public async Task Connect_Send_Receive() + { + var crossPlatformSocket = new CrossPlatformSocket(); + await crossPlatformSocket.ConnectAsync("www.google.de", 80, CancellationToken.None); + + var requestBuffer = Encoding.UTF8.GetBytes("GET / HTTP/1.1\r\nHost: www.google.de\r\n\r\n"); + await crossPlatformSocket.SendAsync(new ArraySegment(requestBuffer), System.Net.Sockets.SocketFlags.None); + + var buffer = new byte[1024]; + var length = await crossPlatformSocket.ReceiveAsync(new ArraySegment(buffer), System.Net.Sockets.SocketFlags.None); + crossPlatformSocket.Dispose(); + + var responseText = Encoding.UTF8.GetString(buffer, 0, length); + + Assert.IsTrue(responseText.Contains("HTTP/1.1 200 OK")); + } + + [TestMethod] + public async Task Try_Connect_Invalid_Host() + { + var crossPlatformSocket = new CrossPlatformSocket(); + + var cancellationToken = new CancellationTokenSource(TimeSpan.FromSeconds(3)); + cancellationToken.Token.Register(() => crossPlatformSocket.Dispose()); + + await crossPlatformSocket.ConnectAsync("www.google.de", 1234, CancellationToken.None); + } + + //[TestMethod] + //public async Task Use_Disconnected_Socket() + //{ + // var crossPlatformSocket = new CrossPlatformSocket(); + + // await crossPlatformSocket.ConnectAsync("www.google.de", 80); + + // var requestBuffer = Encoding.UTF8.GetBytes("GET /wrong_uri HTTP/1.1\r\nConnection: close\r\n\r\n"); + // await crossPlatformSocket.SendAsync(new ArraySegment(requestBuffer), System.Net.Sockets.SocketFlags.None); + + // var buffer = new byte[64000]; + // var length = await crossPlatformSocket.ReceiveAsync(new ArraySegment(buffer), System.Net.Sockets.SocketFlags.None); + + // await Task.Delay(500); + + // await crossPlatformSocket.SendAsync(new ArraySegment(requestBuffer), System.Net.Sockets.SocketFlags.None); + //} + + [TestMethod] + public async Task Set_Options() + { + var crossPlatformSocket = new CrossPlatformSocket(); + + Assert.IsFalse(crossPlatformSocket.ReuseAddress); + crossPlatformSocket.ReuseAddress = true; + Assert.IsTrue(crossPlatformSocket.ReuseAddress); + + Assert.IsFalse(crossPlatformSocket.NoDelay); + crossPlatformSocket.NoDelay = true; + Assert.IsTrue(crossPlatformSocket.NoDelay); + } + } +} diff --git a/Tests/MQTTnet.Core.Tests/MqttTcpChannel_Tests.cs b/Tests/MQTTnet.Core.Tests/MqttTcpChannel_Tests.cs index a4b0ca7..6e78dec 100644 --- a/Tests/MQTTnet.Core.Tests/MqttTcpChannel_Tests.cs +++ b/Tests/MQTTnet.Core.Tests/MqttTcpChannel_Tests.cs @@ -1,10 +1,10 @@ -using System; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using MQTTnet.Implementations; +using System; using System.Net; using System.Net.Sockets; using System.Threading; using System.Threading.Tasks; -using Microsoft.VisualStudio.TestTools.UnitTesting; -using MQTTnet.Implementations; namespace MQTTnet.Tests { @@ -15,7 +15,7 @@ namespace MQTTnet.Tests public async Task Dispose_Channel_While_Used() { var ct = new CancellationTokenSource(); - var serverSocket = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); + var serverSocket = new CrossPlatformSocket(AddressFamily.InterNetwork); try { @@ -28,18 +28,18 @@ namespace MQTTnet.Tests { while (!ct.IsCancellationRequested) { - var client = await PlatformAbstractionLayer.AcceptAsync(serverSocket); + var client = await serverSocket.AcceptAsync(); var data = new byte[] { 128 }; - await PlatformAbstractionLayer.SendAsync(client, new ArraySegment(data), SocketFlags.None); + await client.SendAsync(new ArraySegment(data), SocketFlags.None); } }, ct.Token); - var clientSocket = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); - await PlatformAbstractionLayer.ConnectAsync(clientSocket, IPAddress.Loopback, 50001); + var clientSocket = new CrossPlatformSocket(AddressFamily.InterNetwork); + await clientSocket.ConnectAsync("localhost", 50001, CancellationToken.None); await Task.Delay(100, ct.Token); - var tcpChannel = new MqttTcpChannel(new NetworkStream(clientSocket, true), "test", null); + var tcpChannel = new MqttTcpChannel(clientSocket.GetStream(), "test", null); var buffer = new byte[1]; await tcpChannel.ReadAsync(buffer, 0, 1, ct.Token); diff --git a/Tests/MQTTnet.Core.Tests/Server_Tests.cs b/Tests/MQTTnet.Core.Tests/Server_Tests.cs index 8121f41..e40eba0 100644 --- a/Tests/MQTTnet.Core.Tests/Server_Tests.cs +++ b/Tests/MQTTnet.Core.Tests/Server_Tests.cs @@ -904,14 +904,13 @@ namespace MQTTnet.Tests await testEnvironment.StartServerAsync(serverOptions); - var connectingFailedException = await Assert.ThrowsExceptionAsync(() => testEnvironment.ConnectClientAsync()); Assert.AreEqual(MqttClientConnectResultCode.NotAuthorized, connectingFailedException.ResultCode); } } + Dictionary _connected; - private Dictionary _connected; private void ConnectionValidationHandler(MqttConnectionValidatorContext eventArgs) { if (_connected.ContainsKey(eventArgs.ClientId)) @@ -919,6 +918,7 @@ namespace MQTTnet.Tests eventArgs.ReasonCode = MqttConnectReasonCode.BadUserNameOrPassword; return; } + _connected[eventArgs.ClientId] = true; eventArgs.ReasonCode = MqttConnectReasonCode.Success; return; @@ -1053,6 +1053,12 @@ namespace MQTTnet.Tests // Connect client with same client ID. Should disconnect existing client. var c2 = await testEnvironment.ConnectClientAsync(clientOptionsBuilder); + await Task.Delay(500); + + flow = string.Join(string.Empty, events); + + Assert.AreEqual("cdc", flow); + c2.UseApplicationMessageReceivedHandler(_ => { lock (events) @@ -1061,15 +1067,10 @@ namespace MQTTnet.Tests } }); - c2.SubscribeAsync("topic").Wait(); - - await Task.Delay(500); - - flow = string.Join(string.Empty, events); - Assert.AreEqual("cdc", flow); + await c2.SubscribeAsync("topic"); // r - c2.PublishAsync("topic").Wait(); + await c2.PublishAsync("topic"); await Task.Delay(500); @@ -1149,15 +1150,15 @@ namespace MQTTnet.Tests { await testEnvironment.StartServerAsync(new MqttServerOptionsBuilder().WithDefaultCommunicationTimeout(TimeSpan.FromSeconds(1))); - var client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); - await PlatformAbstractionLayer.ConnectAsync(client, "localhost", testEnvironment.ServerPort); + var client = new CrossPlatformSocket(AddressFamily.InterNetwork); + await client.ConnectAsync("localhost", testEnvironment.ServerPort, CancellationToken.None); // Don't send anything. The server should close the connection. await Task.Delay(TimeSpan.FromSeconds(3)); try { - var receivedBytes = await PlatformAbstractionLayer.ReceiveAsync(client, new ArraySegment(new byte[10]), SocketFlags.Partial); + var receivedBytes = await client.ReceiveAsync(new ArraySegment(new byte[10]), SocketFlags.Partial); if (receivedBytes == 0) { return; @@ -1180,17 +1181,17 @@ namespace MQTTnet.Tests // Send an invalid packet and ensure that the server will close the connection and stay in a waiting state // forever. This is security related. - var client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); - await PlatformAbstractionLayer.ConnectAsync(client, "localhost", testEnvironment.ServerPort); + var client = new CrossPlatformSocket(AddressFamily.InterNetwork); + await client.ConnectAsync("localhost", testEnvironment.ServerPort, CancellationToken.None); var buffer = Encoding.UTF8.GetBytes("Garbage"); - client.Send(buffer, buffer.Length, SocketFlags.None); + await client.SendAsync(new ArraySegment(buffer), SocketFlags.None); await Task.Delay(TimeSpan.FromSeconds(3)); try { - var receivedBytes = await PlatformAbstractionLayer.ReceiveAsync(client, new ArraySegment(new byte[10]), SocketFlags.Partial); + var receivedBytes = await client.ReceiveAsync(new ArraySegment(new byte[10]), SocketFlags.Partial); if (receivedBytes == 0) { return;