Browse Source

Improve socket handling.

release/3.x.x
Christian Kratky 4 years ago
parent
commit
ad128c7889
11 changed files with 438 additions and 194 deletions
  1. +229
    -0
      Source/MQTTnet/Implementations/CrossPlatformSocket.cs
  2. +18
    -20
      Source/MQTTnet/Implementations/MqttTcpChannel.cs
  3. +7
    -7
      Source/MQTTnet/Implementations/MqttTcpServerListener.cs
  4. +1
    -86
      Source/MQTTnet/Implementations/PlatformAbstractionLayer.cs
  5. +14
    -14
      Source/MQTTnet/Server/MqttClientConnection.cs
  6. +5
    -6
      Source/MQTTnet/Server/MqttClientSessionsManager.cs
  7. +9
    -9
      Source/MQTTnet/Server/MqttClientSubscriptionsManager.cs
  8. +55
    -27
      Source/MQTTnet/Server/MqttServerEventDispatcher.cs
  9. +74
    -0
      Tests/MQTTnet.Core.Tests/CrossPlatformSocket_Tests.cs
  10. +9
    -9
      Tests/MQTTnet.Core.Tests/MqttTcpChannel_Tests.cs
  11. +17
    -16
      Tests/MQTTnet.Core.Tests/Server_Tests.cs

+ 229
- 0
Source/MQTTnet/Implementations/CrossPlatformSocket.cs View File

@@ -0,0 +1,229 @@
using System;
using System.Net;
using System.Net.Sockets;
using System.Threading;
using System.Threading.Tasks;

namespace MQTTnet.Implementations
{
public sealed class CrossPlatformSocket : IDisposable
{
readonly Socket _socket;

public CrossPlatformSocket(AddressFamily addressFamily)
{
_socket = new Socket(addressFamily, SocketType.Stream, ProtocolType.Tcp);
}

public CrossPlatformSocket()
{
// Having this contructor is important because avoiding the address family as parameter
// will make use of dual mode in the .net framework.
_socket = new Socket(SocketType.Stream, ProtocolType.Tcp);
}

public CrossPlatformSocket(Socket socket)
{
_socket = socket ?? throw new ArgumentNullException(nameof(socket));
}

public bool NoDelay
{
get
{
return (int)_socket.GetSocketOption(SocketOptionLevel.Tcp, SocketOptionName.NoDelay) > 0;
}

set
{
_socket.SetSocketOption(SocketOptionLevel.Tcp, SocketOptionName.NoDelay, value ? 1 : 0);
}
}

public bool DualMode
{
get
{
return _socket.DualMode;
}

set
{
_socket.DualMode = value;
}
}

public int ReceiveBufferSize
{
get
{
return _socket.ReceiveBufferSize;
}

set
{
_socket.ReceiveBufferSize = value;
}
}

public int SendBufferSize
{
get
{
return _socket.SendBufferSize;
}

set
{
_socket.SendBufferSize = value;
}
}

public EndPoint RemoteEndPoint
{
get
{
return _socket.RemoteEndPoint;
}
}

public bool ReuseAddress
{
get
{
return (int)_socket.GetSocketOption(SocketOptionLevel.Socket, SocketOptionName.ReuseAddress) != 0;
}

set
{
_socket.SetSocketOption(SocketOptionLevel.Socket, SocketOptionName.ReuseAddress, value ? 1 : 0);
}
}

public async Task<CrossPlatformSocket> AcceptAsync()
{
try
{
#if NET452 || NET461
var clientSocket = await Task.Factory.FromAsync(_socket.BeginAccept, _socket.EndAccept, null).ConfigureAwait(false);
return new CrossPlatformSocket(clientSocket);
#else
var clientSocket = await _socket.AcceptAsync().ConfigureAwait(false);
return new CrossPlatformSocket(clientSocket);
#endif
}
catch (ObjectDisposedException)
{
// This will happen when _socket.EndAccept gets called by Task library but the socket is already disposed.
return null;
}
}

public void Bind(EndPoint localEndPoint)
{
if (localEndPoint is null) throw new ArgumentNullException(nameof(localEndPoint));

_socket.Bind(localEndPoint);
}

public void Listen(int connectionBacklog)
{
_socket.Listen(connectionBacklog);
}

public async Task ConnectAsync(string host, int port, CancellationToken cancellationToken)
{
if (host is null) throw new ArgumentNullException(nameof(host));

try
{
// Workaround for: workaround for https://github.com/dotnet/corefx/issues/24430
using (cancellationToken.Register(() => _socket.Dispose()))
{
cancellationToken.ThrowIfCancellationRequested();

#if NET452 || NET461
await Task.Factory.FromAsync(_socket.BeginConnect, _socket.EndConnect, host, port, null).ConfigureAwait(false);
#else
await _socket.ConnectAsync(host, port).ConfigureAwait(false);
#endif
}
}
catch (ObjectDisposedException)
{
// This will happen when _socket.EndConnect gets called by Task library but the socket is already disposed.
}
}

public async Task SendAsync(ArraySegment<byte> buffer, SocketFlags socketFlags)
{
try
{
#if NET452 || NET461
await Task.Factory.FromAsync(SocketWrapper.BeginSend, _socket.EndSend, new SocketWrapper(_socket, buffer, socketFlags)).ConfigureAwait(false);
#else
await _socket.SendAsync(buffer, socketFlags).ConfigureAwait(false);
#endif
}
catch (ObjectDisposedException)
{
// This will happen when _socket.EndConnect gets called by Task library but the socket is already disposed.
}
}

public async Task<int> ReceiveAsync(ArraySegment<byte> buffer, SocketFlags socketFlags)
{
try
{
#if NET452 || NET461
return await Task.Factory.FromAsync(SocketWrapper.BeginReceive, _socket.EndReceive, new SocketWrapper(_socket, buffer, socketFlags)).ConfigureAwait(false);
#else
return await _socket.ReceiveAsync(buffer, socketFlags).ConfigureAwait(false);
#endif
}
catch (ObjectDisposedException)
{
// This will happen when _socket.EndReceive gets called by Task library but the socket is already disposed.
return -1;
}
}

public NetworkStream GetStream()
{
return new NetworkStream(_socket, true);
}

public void Dispose()
{
_socket?.Dispose();
}

#if NET452 || NET461
class SocketWrapper
{
readonly Socket _socket;
readonly ArraySegment<byte> _buffer;
readonly SocketFlags _socketFlags;

public SocketWrapper(Socket socket, ArraySegment<byte> buffer, SocketFlags socketFlags)
{
_socket = socket;
_buffer = buffer;
_socketFlags = socketFlags;
}

public static IAsyncResult BeginSend(AsyncCallback callback, object state)
{
var socketWrapper = (SocketWrapper)state;
return socketWrapper._socket.BeginSend(socketWrapper._buffer.Array, socketWrapper._buffer.Offset, socketWrapper._buffer.Count, socketWrapper._socketFlags, callback, state);
}

public static IAsyncResult BeginReceive(AsyncCallback callback, object state)
{
var socketWrapper = (SocketWrapper)state;
return socketWrapper._socket.BeginReceive(socketWrapper._buffer.Array, socketWrapper._buffer.Offset, socketWrapper._buffer.Count, socketWrapper._socketFlags, callback, state);
}
}
#endif
}
}

+ 18
- 20
Source/MQTTnet/Implementations/MqttTcpChannel.cs View File

@@ -46,15 +46,15 @@ namespace MQTTnet.Implementations


public async Task ConnectAsync(CancellationToken cancellationToken) public async Task ConnectAsync(CancellationToken cancellationToken)
{ {
Socket socket;
CrossPlatformSocket socket;


if (_options.AddressFamily == AddressFamily.Unspecified) if (_options.AddressFamily == AddressFamily.Unspecified)
{ {
socket = new Socket(SocketType.Stream, ProtocolType.Tcp);
socket = new CrossPlatformSocket();
} }
else else
{ {
socket = new Socket(_options.AddressFamily, SocketType.Stream, ProtocolType.Tcp);
socket = new CrossPlatformSocket(_options.AddressFamily);
} }


socket.ReceiveBufferSize = _options.BufferSize; socket.ReceiveBufferSize = _options.BufferSize;
@@ -69,20 +69,24 @@ namespace MQTTnet.Implementations
socket.DualMode = _options.DualMode.Value; socket.DualMode = _options.DualMode.Value;
} }


// Workaround for: workaround for https://github.com/dotnet/corefx/issues/24430
using (cancellationToken.Register(() => socket.Dispose()))
{
await PlatformAbstractionLayer.ConnectAsync(socket, _options.Server, _options.GetPort()).ConfigureAwait(false);
}
await socket.ConnectAsync(_options.Server, _options.GetPort(), cancellationToken).ConfigureAwait(false);


var networkStream = new NetworkStream(socket, true);
var networkStream = socket.GetStream();


if (_options.TlsOptions.UseTls) if (_options.TlsOptions.UseTls)
{ {
var sslStream = new SslStream(networkStream, false, InternalUserCertificateValidationCallback); var sslStream = new SslStream(networkStream, false, InternalUserCertificateValidationCallback);
_stream = sslStream;
try
{
await sslStream.AuthenticateAsClientAsync(_options.Server, LoadCertificates(), _options.TlsOptions.SslProtocol, !_options.TlsOptions.IgnoreCertificateRevocationErrors).ConfigureAwait(false);
}
catch
{
sslStream.Dispose();
throw;
}


await sslStream.AuthenticateAsClientAsync(_options.Server, LoadCertificates(), _options.TlsOptions.SslProtocol, !_options.TlsOptions.IgnoreCertificateRevocationErrors).ConfigureAwait(false);
_stream = sslStream;
} }
else else
{ {
@@ -107,17 +111,14 @@ namespace MQTTnet.Implementations
// Workaround for: https://github.com/dotnet/corefx/issues/24430 // Workaround for: https://github.com/dotnet/corefx/issues/24430
using (cancellationToken.Register(Dispose)) using (cancellationToken.Register(Dispose))
{ {
if (cancellationToken.IsCancellationRequested)
{
return 0;
}
cancellationToken.ThrowIfCancellationRequested();


return await _stream.ReadAsync(buffer, offset, count, cancellationToken).ConfigureAwait(false); return await _stream.ReadAsync(buffer, offset, count, cancellationToken).ConfigureAwait(false);
} }
} }
catch (ObjectDisposedException) catch (ObjectDisposedException)
{ {
return 0;
return -1;
} }
catch (IOException exception) catch (IOException exception)
{ {
@@ -139,10 +140,7 @@ namespace MQTTnet.Implementations
// Workaround for: https://github.com/dotnet/corefx/issues/24430 // Workaround for: https://github.com/dotnet/corefx/issues/24430
using (cancellationToken.Register(Dispose)) using (cancellationToken.Register(Dispose))
{ {
if (cancellationToken.IsCancellationRequested)
{
return;
}
cancellationToken.ThrowIfCancellationRequested();


await _stream.WriteAsync(buffer, offset, count, cancellationToken).ConfigureAwait(false); await _stream.WriteAsync(buffer, offset, count, cancellationToken).ConfigureAwait(false);
} }


+ 7
- 7
Source/MQTTnet/Implementations/MqttTcpServerListener.cs View File

@@ -23,7 +23,7 @@ namespace MQTTnet.Implementations
readonly MqttServerTlsTcpEndpointOptions _tlsOptions; readonly MqttServerTlsTcpEndpointOptions _tlsOptions;
readonly X509Certificate2 _tlsCertificate; readonly X509Certificate2 _tlsCertificate;


private Socket _socket;
private CrossPlatformSocket _socket;
private IPEndPoint _localEndPoint; private IPEndPoint _localEndPoint;


public MqttTcpServerListener( public MqttTcpServerListener(
@@ -59,18 +59,18 @@ namespace MQTTnet.Implementations


_logger.Info($"Starting TCP listener for {_localEndPoint} TLS={_tlsCertificate != null}."); _logger.Info($"Starting TCP listener for {_localEndPoint} TLS={_tlsCertificate != null}.");


_socket = new Socket(_addressFamily, SocketType.Stream, ProtocolType.Tcp);
_socket = new CrossPlatformSocket(_addressFamily);


// Usage of socket options is described here: https://docs.microsoft.com/en-us/dotnet/api/system.net.sockets.socket.setsocketoption?view=netcore-2.2 // Usage of socket options is described here: https://docs.microsoft.com/en-us/dotnet/api/system.net.sockets.socket.setsocketoption?view=netcore-2.2


if (_options.ReuseAddress) if (_options.ReuseAddress)
{ {
_socket.SetSocketOption(SocketOptionLevel.Socket, SocketOptionName.ReuseAddress, true);
_socket.ReuseAddress = true;
} }


if (_options.NoDelay) if (_options.NoDelay)
{ {
_socket.SetSocketOption(SocketOptionLevel.Tcp, SocketOptionName.NoDelay, true);
_socket.NoDelay = true;
} }


_socket.Bind(_localEndPoint); _socket.Bind(_localEndPoint);
@@ -107,7 +107,7 @@ namespace MQTTnet.Implementations
{ {
try try
{ {
var clientSocket = await PlatformAbstractionLayer.AcceptAsync(_socket).ConfigureAwait(false);
var clientSocket = await _socket.AcceptAsync().ConfigureAwait(false);
if (clientSocket == null) if (clientSocket == null)
{ {
continue; continue;
@@ -135,7 +135,7 @@ namespace MQTTnet.Implementations
} }
} }


async Task TryHandleClientConnectionAsync(Socket clientSocket)
async Task TryHandleClientConnectionAsync(CrossPlatformSocket clientSocket)
{ {
Stream stream = null; Stream stream = null;
string remoteEndPoint = null; string remoteEndPoint = null;
@@ -151,7 +151,7 @@ namespace MQTTnet.Implementations


clientSocket.NoDelay = _options.NoDelay; clientSocket.NoDelay = _options.NoDelay;


stream = new NetworkStream(clientSocket, true);
stream = clientSocket.GetStream();


X509Certificate2 clientCertificate = null; X509Certificate2 clientCertificate = null;




+ 1
- 86
Source/MQTTnet/Implementations/PlatformAbstractionLayer.cs View File

@@ -1,94 +1,9 @@
using System;
using System.Net;
using System.Net.Sockets;
using System.Threading.Tasks;
using System.Threading.Tasks;


namespace MQTTnet.Implementations namespace MQTTnet.Implementations
{ {
public static class PlatformAbstractionLayer public static class PlatformAbstractionLayer
{ {
// TODO: Consider creating primitives like "MqttNetSocket" which will wrap all required methods and do the platform stuff.
public static async Task<Socket> AcceptAsync(Socket socket)
{
#if NET452 || NET461
try
{
return await Task.Factory.FromAsync(socket.BeginAccept, socket.EndAccept, null).ConfigureAwait(false);
}
catch (ObjectDisposedException)
{
return null;
}
#else
return await socket.AcceptAsync().ConfigureAwait(false);
#endif
}


public static Task ConnectAsync(Socket socket, IPAddress ip, int port)
{
#if NET452 || NET461
return Task.Factory.FromAsync(socket.BeginConnect, socket.EndConnect, ip, port, null);
#else
return socket.ConnectAsync(ip, port);
#endif
}

public static Task ConnectAsync(Socket socket, string host, int port)
{
#if NET452 || NET461
return Task.Factory.FromAsync(socket.BeginConnect, socket.EndConnect, host, port, null);
#else
return socket.ConnectAsync(host, port);
#endif
}

#if NET452 || NET461
public class SocketWrapper
{
private readonly Socket _socket;
private readonly ArraySegment<byte> _buffer;
private readonly SocketFlags _socketFlags;

public SocketWrapper(Socket socket, ArraySegment<byte> buffer, SocketFlags socketFlags)
{
_socket = socket;
_buffer = buffer;
_socketFlags = socketFlags;
}

public static IAsyncResult BeginSend(AsyncCallback callback, object state)
{
var real = (SocketWrapper)state;
return real._socket.BeginSend(real._buffer.Array, real._buffer.Offset, real._buffer.Count, real._socketFlags, callback, state);
}

public static IAsyncResult BeginReceive(AsyncCallback callback, object state)
{
var real = (SocketWrapper)state;
return real._socket.BeginReceive(real._buffer.Array, real._buffer.Offset, real._buffer.Count, real._socketFlags, callback, state);
}
}
#endif

public static Task SendAsync(Socket socket, ArraySegment<byte> buffer, SocketFlags socketFlags)
{
#if NET452 || NET461
return Task.Factory.FromAsync(SocketWrapper.BeginSend, socket.EndSend, new SocketWrapper(socket, buffer, socketFlags));
#else
return socket.SendAsync(buffer, socketFlags);
#endif
}

public static Task<int> ReceiveAsync(Socket socket, ArraySegment<byte> buffer, SocketFlags socketFlags)
{
#if NET452 || NET461
return Task.Factory.FromAsync(SocketWrapper.BeginReceive, socket.EndReceive, new SocketWrapper(socket, buffer, socketFlags));
#else
return socket.ReceiveAsync(buffer, socketFlags);
#endif
}

public static Task CompletedTask public static Task CompletedTask
{ {
get get


+ 14
- 14
Source/MQTTnet/Server/MqttClientConnection.cs View File

@@ -15,7 +15,7 @@ using System.Threading.Tasks;


namespace MQTTnet.Server namespace MQTTnet.Server
{ {
public class MqttClientConnection : IDisposable
public sealed class MqttClientConnection : IDisposable
{ {
private readonly MqttPacketIdentifierProvider _packetIdentifierProvider = new MqttPacketIdentifierProvider(); private readonly MqttPacketIdentifierProvider _packetIdentifierProvider = new MqttPacketIdentifierProvider();
private readonly MqttPacketDispatcher _packetDispatcher = new MqttPacketDispatcher(); private readonly MqttPacketDispatcher _packetDispatcher = new MqttPacketDispatcher();
@@ -124,7 +124,7 @@ namespace MQTTnet.Server
return _packageReceiverTask; return _packageReceiverTask;
} }


private async Task<MqttClientDisconnectType> RunInternalAsync(MqttConnectionValidatorContext connectionValidatorContext)
async Task<MqttClientDisconnectType> RunInternalAsync(MqttConnectionValidatorContext connectionValidatorContext)
{ {
var disconnectType = MqttClientDisconnectType.NotClean; var disconnectType = MqttClientDisconnectType.NotClean;
try try
@@ -251,12 +251,12 @@ namespace MQTTnet.Server
return disconnectType; return disconnectType;
} }


private void StopInternal()
void StopInternal()
{ {
_cancellationToken.Cancel(false); _cancellationToken.Cancel(false);
} }


private async Task EnqueueSubscribedRetainedMessagesAsync(ICollection<TopicFilter> topicFilters)
async Task EnqueueSubscribedRetainedMessagesAsync(ICollection<TopicFilter> topicFilters)
{ {
var retainedMessages = await _retainedMessagesManager.GetSubscribedMessagesAsync(topicFilters).ConfigureAwait(false); var retainedMessages = await _retainedMessagesManager.GetSubscribedMessagesAsync(topicFilters).ConfigureAwait(false);
foreach (var applicationMessage in retainedMessages) foreach (var applicationMessage in retainedMessages)
@@ -265,7 +265,7 @@ namespace MQTTnet.Server
} }
} }


private async Task HandleIncomingSubscribePacketAsync(MqttSubscribePacket subscribePacket)
async Task HandleIncomingSubscribePacketAsync(MqttSubscribePacket subscribePacket)
{ {
// TODO: Let the channel adapter create the packet. // TODO: Let the channel adapter create the packet.
var subscribeResult = await Session.SubscriptionsManager.SubscribeAsync(subscribePacket, ConnectPacket).ConfigureAwait(false); var subscribeResult = await Session.SubscriptionsManager.SubscribeAsync(subscribePacket, ConnectPacket).ConfigureAwait(false);
@@ -281,14 +281,14 @@ namespace MQTTnet.Server
await EnqueueSubscribedRetainedMessagesAsync(subscribePacket.TopicFilters).ConfigureAwait(false); await EnqueueSubscribedRetainedMessagesAsync(subscribePacket.TopicFilters).ConfigureAwait(false);
} }


private async Task HandleIncomingUnsubscribePacketAsync(MqttUnsubscribePacket unsubscribePacket)
async Task HandleIncomingUnsubscribePacketAsync(MqttUnsubscribePacket unsubscribePacket)
{ {
// TODO: Let the channel adapter create the packet. // TODO: Let the channel adapter create the packet.
var unsubscribeResult = await Session.SubscriptionsManager.UnsubscribeAsync(unsubscribePacket).ConfigureAwait(false); var unsubscribeResult = await Session.SubscriptionsManager.UnsubscribeAsync(unsubscribePacket).ConfigureAwait(false);
await SendAsync(unsubscribeResult).ConfigureAwait(false); await SendAsync(unsubscribeResult).ConfigureAwait(false);
} }


private Task HandleIncomingPublishPacketAsync(MqttPublishPacket publishPacket)
Task HandleIncomingPublishPacketAsync(MqttPublishPacket publishPacket)
{ {
Interlocked.Increment(ref _sentApplicationMessagesCount); Interlocked.Increment(ref _sentApplicationMessagesCount);


@@ -313,7 +313,7 @@ namespace MQTTnet.Server
} }
} }


private Task HandleIncomingPublishPacketWithQoS0Async(MqttPublishPacket publishPacket)
Task HandleIncomingPublishPacketWithQoS0Async(MqttPublishPacket publishPacket)
{ {
var applicationMessage = _dataConverter.CreateApplicationMessage(publishPacket); var applicationMessage = _dataConverter.CreateApplicationMessage(publishPacket);


@@ -322,7 +322,7 @@ namespace MQTTnet.Server
return Task.FromResult(0); return Task.FromResult(0);
} }


private Task HandleIncomingPublishPacketWithQoS1Async(MqttPublishPacket publishPacket)
Task HandleIncomingPublishPacketWithQoS1Async(MqttPublishPacket publishPacket)
{ {
var applicationMessage = _dataConverter.CreateApplicationMessage(publishPacket); var applicationMessage = _dataConverter.CreateApplicationMessage(publishPacket);
_sessionsManager.DispatchApplicationMessage(applicationMessage, this); _sessionsManager.DispatchApplicationMessage(applicationMessage, this);
@@ -331,7 +331,7 @@ namespace MQTTnet.Server
return SendAsync(pubAckPacket); return SendAsync(pubAckPacket);
} }


private Task HandleIncomingPublishPacketWithQoS2Async(MqttPublishPacket publishPacket)
Task HandleIncomingPublishPacketWithQoS2Async(MqttPublishPacket publishPacket)
{ {
var applicationMessage = _dataConverter.CreateApplicationMessage(publishPacket); var applicationMessage = _dataConverter.CreateApplicationMessage(publishPacket);
_sessionsManager.DispatchApplicationMessage(applicationMessage, this); _sessionsManager.DispatchApplicationMessage(applicationMessage, this);
@@ -345,7 +345,7 @@ namespace MQTTnet.Server
return SendAsync(pubRecPacket); return SendAsync(pubRecPacket);
} }


private async Task SendPendingPacketsAsync(CancellationToken cancellationToken)
async Task SendPendingPacketsAsync(CancellationToken cancellationToken)
{ {
MqttQueuedApplicationMessage queuedApplicationMessage = null; MqttQueuedApplicationMessage queuedApplicationMessage = null;
MqttPublishPacket publishPacket = null; MqttPublishPacket publishPacket = null;
@@ -459,7 +459,7 @@ namespace MQTTnet.Server
} }
} }


private async Task SendAsync(MqttBasePacket packet)
async Task SendAsync(MqttBasePacket packet)
{ {
await _channelAdapter.SendPacketAsync(packet, _serverOptions.DefaultCommunicationTimeout, _cancellationToken.Token).ConfigureAwait(false); await _channelAdapter.SendPacketAsync(packet, _serverOptions.DefaultCommunicationTimeout, _cancellationToken.Token).ConfigureAwait(false);


@@ -471,12 +471,12 @@ namespace MQTTnet.Server
} }
} }


private void OnAdapterReadingPacketCompleted()
void OnAdapterReadingPacketCompleted()
{ {
_keepAliveMonitor?.Resume(); _keepAliveMonitor?.Resume();
} }


private void OnAdapterReadingPacketStarted()
void OnAdapterReadingPacketStarted()
{ {
_keepAliveMonitor?.Pause(); _keepAliveMonitor?.Pause();
} }


+ 5
- 6
Source/MQTTnet/Server/MqttClientSessionsManager.cs View File

@@ -207,7 +207,7 @@ namespace MQTTnet.Server
applicationMessage = interceptorContext.ApplicationMessage; applicationMessage = interceptorContext.ApplicationMessage;
} }


await _eventDispatcher.HandleApplicationMessageReceivedAsync(sender?.ClientId, applicationMessage).ConfigureAwait(false);
await _eventDispatcher.SafeNotifyApplicationMessageReceivedAsync(sender?.ClientId, applicationMessage).ConfigureAwait(false);


if (applicationMessage.Retain) if (applicationMessage.Retain)
{ {
@@ -237,7 +237,7 @@ namespace MQTTnet.Server
string clientId = null; string clientId = null;
var clientWasConnected = true; var clientWasConnected = true;


MqttConnectPacket connectPacket = null;
MqttConnectPacket connectPacket;


try try
{ {
@@ -259,8 +259,6 @@ namespace MQTTnet.Server


var connectionValidatorContext = await ValidateConnectionAsync(connectPacket, channelAdapter).ConfigureAwait(false); var connectionValidatorContext = await ValidateConnectionAsync(connectPacket, channelAdapter).ConfigureAwait(false);


clientId = connectPacket.ClientId;

if (connectionValidatorContext.ReasonCode != MqttConnectReasonCode.Success) if (connectionValidatorContext.ReasonCode != MqttConnectReasonCode.Success)
{ {
clientWasConnected = false; clientWasConnected = false;
@@ -272,9 +270,10 @@ namespace MQTTnet.Server
return; return;
} }


clientId = connectPacket.ClientId;
var connection = await CreateConnectionAsync(connectPacket, connectionValidatorContext, channelAdapter).ConfigureAwait(false); var connection = await CreateConnectionAsync(connectPacket, connectionValidatorContext, channelAdapter).ConfigureAwait(false);


await _eventDispatcher.HandleClientConnectedAsync(clientId).ConfigureAwait(false);
await _eventDispatcher.SafeNotifyClientConnectedAsync(clientId).ConfigureAwait(false);


disconnectType = await connection.RunAsync(connectionValidatorContext).ConfigureAwait(false); disconnectType = await connection.RunAsync(connectionValidatorContext).ConfigureAwait(false);
} }
@@ -303,7 +302,7 @@ namespace MQTTnet.Server


if (clientId != null) if (clientId != null)
{ {
await _eventDispatcher.TryHandleClientDisconnectedAsync(clientId, disconnectType).ConfigureAwait(false);
await _eventDispatcher.SafeNotifyClientDisconnectedAsync(clientId, disconnectType).ConfigureAwait(false);
} }
} }
} }


+ 9
- 9
Source/MQTTnet/Server/MqttClientSubscriptionsManager.cs View File

@@ -1,9 +1,9 @@
using System;
using MQTTnet.Packets;
using MQTTnet.Protocol;
using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Linq; using System.Linq;
using System.Threading.Tasks; using System.Threading.Tasks;
using MQTTnet.Packets;
using MQTTnet.Protocol;


namespace MQTTnet.Server namespace MQTTnet.Server
{ {
@@ -67,7 +67,7 @@ namespace MQTTnet.Server
_subscriptions[finalTopicFilter.Topic] = finalTopicFilter; _subscriptions[finalTopicFilter.Topic] = finalTopicFilter;
} }


await _eventDispatcher.HandleClientSubscribedTopicAsync(_clientSession.ClientId, finalTopicFilter).ConfigureAwait(false);
await _eventDispatcher.SafeNotifyClientSubscribedTopicAsync(_clientSession.ClientId, finalTopicFilter).ConfigureAwait(false);
} }
} }


@@ -83,7 +83,7 @@ namespace MQTTnet.Server
var interceptorContext = await InterceptSubscribeAsync(topicFilter).ConfigureAwait(false); var interceptorContext = await InterceptSubscribeAsync(topicFilter).ConfigureAwait(false);
if (!interceptorContext.AcceptSubscription) if (!interceptorContext.AcceptSubscription)
{ {
continue;
continue;
} }


if (interceptorContext.AcceptSubscription) if (interceptorContext.AcceptSubscription)
@@ -93,7 +93,7 @@ namespace MQTTnet.Server
_subscriptions[topicFilter.Topic] = topicFilter; _subscriptions[topicFilter.Topic] = topicFilter;
} }


await _eventDispatcher.HandleClientSubscribedTopicAsync(_clientSession.ClientId, topicFilter).ConfigureAwait(false);
await _eventDispatcher.SafeNotifyClientSubscribedTopicAsync(_clientSession.ClientId, topicFilter).ConfigureAwait(false);
} }
} }
} }
@@ -131,9 +131,9 @@ namespace MQTTnet.Server


foreach (var topicFilter in unsubscribePacket.TopicFilters) foreach (var topicFilter in unsubscribePacket.TopicFilters)
{ {
await _eventDispatcher.HandleClientUnsubscribedTopicAsync(_clientSession.ClientId, topicFilter).ConfigureAwait(false);
await _eventDispatcher.SafeNotifyClientUnsubscribedTopicAsync(_clientSession.ClientId, topicFilter).ConfigureAwait(false);
} }
return unsubAckPacket; return unsubAckPacket;
} }


@@ -152,7 +152,7 @@ namespace MQTTnet.Server
lock (_subscriptions) lock (_subscriptions)
{ {
_subscriptions.Remove(topicFilter); _subscriptions.Remove(topicFilter);
}
}
} }
} }




+ 55
- 27
Source/MQTTnet/Server/MqttServerEventDispatcher.cs View File

@@ -7,7 +7,7 @@ namespace MQTTnet.Server
{ {
public class MqttServerEventDispatcher public class MqttServerEventDispatcher
{ {
private readonly IMqttNetLogger _logger;
readonly IMqttNetLogger _logger;


public MqttServerEventDispatcher(IMqttNetLogger logger) public MqttServerEventDispatcher(IMqttNetLogger logger)
{ {
@@ -24,18 +24,25 @@ namespace MQTTnet.Server


public IMqttApplicationMessageReceivedHandler ApplicationMessageReceivedHandler { get; set; } public IMqttApplicationMessageReceivedHandler ApplicationMessageReceivedHandler { get; set; }


public Task HandleClientConnectedAsync(string clientId)
public async Task SafeNotifyClientConnectedAsync(string clientId)
{ {
var handler = ClientConnectedHandler;
if (handler == null)
try
{ {
return Task.FromResult(0);
}
var handler = ClientConnectedHandler;
if (handler == null)
{
return;
}


return handler.HandleClientConnectedAsync(new MqttServerClientConnectedEventArgs(clientId));
await handler.HandleClientConnectedAsync(new MqttServerClientConnectedEventArgs(clientId)).ConfigureAwait(false);
}
catch (Exception exception)
{
_logger.Error(exception, "Error while handling custom 'ClientConnected' event.");
}
} }


public async Task TryHandleClientDisconnectedAsync(string clientId, MqttClientDisconnectType disconnectType)
public async Task SafeNotifyClientDisconnectedAsync(string clientId, MqttClientDisconnectType disconnectType)
{ {
try try
{ {
@@ -49,41 +56,62 @@ namespace MQTTnet.Server
} }
catch (Exception exception) catch (Exception exception)
{ {
_logger.Error(exception, "Error while handling 'ClientDisconnected' event.");
_logger.Error(exception, "Error while handling custom 'ClientDisconnected' event.");
} }
} }


public Task HandleClientSubscribedTopicAsync(string clientId, TopicFilter topicFilter)
public async Task SafeNotifyClientSubscribedTopicAsync(string clientId, TopicFilter topicFilter)
{ {
var handler = ClientSubscribedTopicHandler;
if (handler == null)
try
{ {
return Task.FromResult(0);
}
var handler = ClientSubscribedTopicHandler;
if (handler == null)
{
return;
}


return handler.HandleClientSubscribedTopicAsync(new MqttServerClientSubscribedTopicEventArgs(clientId, topicFilter));
await handler.HandleClientSubscribedTopicAsync(new MqttServerClientSubscribedTopicEventArgs(clientId, topicFilter)).ConfigureAwait(false);
}
catch (Exception exception)
{
_logger.Error(exception, "Error while handling custom 'ClientSubscribedTopic' event.");
}
} }


public Task HandleClientUnsubscribedTopicAsync(string clientId, string topicFilter)
public async Task SafeNotifyClientUnsubscribedTopicAsync(string clientId, string topicFilter)
{ {
var handler = ClientUnsubscribedTopicHandler;
if (handler == null)
try
{ {
return Task.FromResult(0);
}
var handler = ClientUnsubscribedTopicHandler;
if (handler == null)
{
return;
}


return handler.HandleClientUnsubscribedTopicAsync(new MqttServerClientUnsubscribedTopicEventArgs(clientId, topicFilter));
await handler.HandleClientUnsubscribedTopicAsync(new MqttServerClientUnsubscribedTopicEventArgs(clientId, topicFilter)).ConfigureAwait(false);
}
catch (Exception exception)
{
_logger.Error(exception, "Error while handling custom 'ClientUnsubscribedTopic' event.");
}
} }


public Task HandleApplicationMessageReceivedAsync(string senderClientId, MqttApplicationMessage applicationMessage)
public async Task SafeNotifyApplicationMessageReceivedAsync(string senderClientId, MqttApplicationMessage applicationMessage)
{ {
var handler = ApplicationMessageReceivedHandler;
if (handler == null)
try
{ {
return Task.FromResult(0);
}
var handler = ApplicationMessageReceivedHandler;
if (handler == null)
{
return;
}


return handler.HandleApplicationMessageReceivedAsync(new MqttApplicationMessageReceivedEventArgs(senderClientId, applicationMessage));
await handler.HandleApplicationMessageReceivedAsync(new MqttApplicationMessageReceivedEventArgs(senderClientId, applicationMessage)).ConfigureAwait(false); ;
}
catch (Exception exception)
{
_logger.Error(exception, "Error while handling custom 'ApplicationMessageReceived' event.");
}
} }
} }
} }

+ 74
- 0
Tests/MQTTnet.Core.Tests/CrossPlatformSocket_Tests.cs View File

@@ -0,0 +1,74 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using MQTTnet.Implementations;
using System;
using System.Text;
using System.Threading;
using System.Threading.Tasks;

namespace MQTTnet.Tests
{
[TestClass]
public class CrossPlatformSocket_Tests
{
[TestMethod]
public async Task Connect_Send_Receive()
{
var crossPlatformSocket = new CrossPlatformSocket();
await crossPlatformSocket.ConnectAsync("www.google.de", 80, CancellationToken.None);

var requestBuffer = Encoding.UTF8.GetBytes("GET / HTTP/1.1\r\nHost: www.google.de\r\n\r\n");
await crossPlatformSocket.SendAsync(new ArraySegment<byte>(requestBuffer), System.Net.Sockets.SocketFlags.None);

var buffer = new byte[1024];
var length = await crossPlatformSocket.ReceiveAsync(new ArraySegment<byte>(buffer), System.Net.Sockets.SocketFlags.None);
crossPlatformSocket.Dispose();

var responseText = Encoding.UTF8.GetString(buffer, 0, length);

Assert.IsTrue(responseText.Contains("HTTP/1.1 200 OK"));
}

[TestMethod]
public async Task Try_Connect_Invalid_Host()
{
var crossPlatformSocket = new CrossPlatformSocket();

var cancellationToken = new CancellationTokenSource(TimeSpan.FromSeconds(3));
cancellationToken.Token.Register(() => crossPlatformSocket.Dispose());

await crossPlatformSocket.ConnectAsync("www.google.de", 1234, CancellationToken.None);
}

//[TestMethod]
//public async Task Use_Disconnected_Socket()
//{
// var crossPlatformSocket = new CrossPlatformSocket();

// await crossPlatformSocket.ConnectAsync("www.google.de", 80);

// var requestBuffer = Encoding.UTF8.GetBytes("GET /wrong_uri HTTP/1.1\r\nConnection: close\r\n\r\n");
// await crossPlatformSocket.SendAsync(new ArraySegment<byte>(requestBuffer), System.Net.Sockets.SocketFlags.None);

// var buffer = new byte[64000];
// var length = await crossPlatformSocket.ReceiveAsync(new ArraySegment<byte>(buffer), System.Net.Sockets.SocketFlags.None);

// await Task.Delay(500);

// await crossPlatformSocket.SendAsync(new ArraySegment<byte>(requestBuffer), System.Net.Sockets.SocketFlags.None);
//}

[TestMethod]
public async Task Set_Options()
{
var crossPlatformSocket = new CrossPlatformSocket();

Assert.IsFalse(crossPlatformSocket.ReuseAddress);
crossPlatformSocket.ReuseAddress = true;
Assert.IsTrue(crossPlatformSocket.ReuseAddress);

Assert.IsFalse(crossPlatformSocket.NoDelay);
crossPlatformSocket.NoDelay = true;
Assert.IsTrue(crossPlatformSocket.NoDelay);
}
}
}

+ 9
- 9
Tests/MQTTnet.Core.Tests/MqttTcpChannel_Tests.cs View File

@@ -1,10 +1,10 @@
using System;
using Microsoft.VisualStudio.TestTools.UnitTesting;
using MQTTnet.Implementations;
using System;
using System.Net; using System.Net;
using System.Net.Sockets; using System.Net.Sockets;
using System.Threading; using System.Threading;
using System.Threading.Tasks; using System.Threading.Tasks;
using Microsoft.VisualStudio.TestTools.UnitTesting;
using MQTTnet.Implementations;


namespace MQTTnet.Tests namespace MQTTnet.Tests
{ {
@@ -15,7 +15,7 @@ namespace MQTTnet.Tests
public async Task Dispose_Channel_While_Used() public async Task Dispose_Channel_While_Used()
{ {
var ct = new CancellationTokenSource(); var ct = new CancellationTokenSource();
var serverSocket = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
var serverSocket = new CrossPlatformSocket(AddressFamily.InterNetwork);


try try
{ {
@@ -28,18 +28,18 @@ namespace MQTTnet.Tests
{ {
while (!ct.IsCancellationRequested) while (!ct.IsCancellationRequested)
{ {
var client = await PlatformAbstractionLayer.AcceptAsync(serverSocket);
var client = await serverSocket.AcceptAsync();
var data = new byte[] { 128 }; var data = new byte[] { 128 };
await PlatformAbstractionLayer.SendAsync(client, new ArraySegment<byte>(data), SocketFlags.None);
await client.SendAsync(new ArraySegment<byte>(data), SocketFlags.None);
} }
}, ct.Token); }, ct.Token);


var clientSocket = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
await PlatformAbstractionLayer.ConnectAsync(clientSocket, IPAddress.Loopback, 50001);
var clientSocket = new CrossPlatformSocket(AddressFamily.InterNetwork);
await clientSocket.ConnectAsync("localhost", 50001, CancellationToken.None);


await Task.Delay(100, ct.Token); await Task.Delay(100, ct.Token);


var tcpChannel = new MqttTcpChannel(new NetworkStream(clientSocket, true), "test", null);
var tcpChannel = new MqttTcpChannel(clientSocket.GetStream(), "test", null);


var buffer = new byte[1]; var buffer = new byte[1];
await tcpChannel.ReadAsync(buffer, 0, 1, ct.Token); await tcpChannel.ReadAsync(buffer, 0, 1, ct.Token);


+ 17
- 16
Tests/MQTTnet.Core.Tests/Server_Tests.cs View File

@@ -904,14 +904,13 @@ namespace MQTTnet.Tests


await testEnvironment.StartServerAsync(serverOptions); await testEnvironment.StartServerAsync(serverOptions);



var connectingFailedException = await Assert.ThrowsExceptionAsync<MqttConnectingFailedException>(() => testEnvironment.ConnectClientAsync()); var connectingFailedException = await Assert.ThrowsExceptionAsync<MqttConnectingFailedException>(() => testEnvironment.ConnectClientAsync());
Assert.AreEqual(MqttClientConnectResultCode.NotAuthorized, connectingFailedException.ResultCode); Assert.AreEqual(MqttClientConnectResultCode.NotAuthorized, connectingFailedException.ResultCode);
} }
} }


Dictionary<string, bool> _connected;


private Dictionary<string, bool> _connected;
private void ConnectionValidationHandler(MqttConnectionValidatorContext eventArgs) private void ConnectionValidationHandler(MqttConnectionValidatorContext eventArgs)
{ {
if (_connected.ContainsKey(eventArgs.ClientId)) if (_connected.ContainsKey(eventArgs.ClientId))
@@ -919,6 +918,7 @@ namespace MQTTnet.Tests
eventArgs.ReasonCode = MqttConnectReasonCode.BadUserNameOrPassword; eventArgs.ReasonCode = MqttConnectReasonCode.BadUserNameOrPassword;
return; return;
} }

_connected[eventArgs.ClientId] = true; _connected[eventArgs.ClientId] = true;
eventArgs.ReasonCode = MqttConnectReasonCode.Success; eventArgs.ReasonCode = MqttConnectReasonCode.Success;
return; return;
@@ -1053,6 +1053,12 @@ namespace MQTTnet.Tests
// Connect client with same client ID. Should disconnect existing client. // Connect client with same client ID. Should disconnect existing client.
var c2 = await testEnvironment.ConnectClientAsync(clientOptionsBuilder); var c2 = await testEnvironment.ConnectClientAsync(clientOptionsBuilder);


await Task.Delay(500);

flow = string.Join(string.Empty, events);

Assert.AreEqual("cdc", flow);

c2.UseApplicationMessageReceivedHandler(_ => c2.UseApplicationMessageReceivedHandler(_ =>
{ {
lock (events) lock (events)
@@ -1061,15 +1067,10 @@ namespace MQTTnet.Tests
} }
}); });


c2.SubscribeAsync("topic").Wait();

await Task.Delay(500);

flow = string.Join(string.Empty, events);
Assert.AreEqual("cdc", flow);
await c2.SubscribeAsync("topic");


// r // r
c2.PublishAsync("topic").Wait();
await c2.PublishAsync("topic");


await Task.Delay(500); await Task.Delay(500);


@@ -1149,15 +1150,15 @@ namespace MQTTnet.Tests
{ {
await testEnvironment.StartServerAsync(new MqttServerOptionsBuilder().WithDefaultCommunicationTimeout(TimeSpan.FromSeconds(1))); await testEnvironment.StartServerAsync(new MqttServerOptionsBuilder().WithDefaultCommunicationTimeout(TimeSpan.FromSeconds(1)));


var client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
await PlatformAbstractionLayer.ConnectAsync(client, "localhost", testEnvironment.ServerPort);
var client = new CrossPlatformSocket(AddressFamily.InterNetwork);
await client.ConnectAsync("localhost", testEnvironment.ServerPort, CancellationToken.None);


// Don't send anything. The server should close the connection. // Don't send anything. The server should close the connection.
await Task.Delay(TimeSpan.FromSeconds(3)); await Task.Delay(TimeSpan.FromSeconds(3));


try try
{ {
var receivedBytes = await PlatformAbstractionLayer.ReceiveAsync(client, new ArraySegment<byte>(new byte[10]), SocketFlags.Partial);
var receivedBytes = await client.ReceiveAsync(new ArraySegment<byte>(new byte[10]), SocketFlags.Partial);
if (receivedBytes == 0) if (receivedBytes == 0)
{ {
return; return;
@@ -1180,17 +1181,17 @@ namespace MQTTnet.Tests


// Send an invalid packet and ensure that the server will close the connection and stay in a waiting state // Send an invalid packet and ensure that the server will close the connection and stay in a waiting state
// forever. This is security related. // forever. This is security related.
var client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
await PlatformAbstractionLayer.ConnectAsync(client, "localhost", testEnvironment.ServerPort);
var client = new CrossPlatformSocket(AddressFamily.InterNetwork);
await client.ConnectAsync("localhost", testEnvironment.ServerPort, CancellationToken.None);


var buffer = Encoding.UTF8.GetBytes("Garbage"); var buffer = Encoding.UTF8.GetBytes("Garbage");
client.Send(buffer, buffer.Length, SocketFlags.None);
await client.SendAsync(new ArraySegment<byte>(buffer), SocketFlags.None);


await Task.Delay(TimeSpan.FromSeconds(3)); await Task.Delay(TimeSpan.FromSeconds(3));


try try
{ {
var receivedBytes = await PlatformAbstractionLayer.ReceiveAsync(client, new ArraySegment<byte>(new byte[10]), SocketFlags.Partial);
var receivedBytes = await client.ReceiveAsync(new ArraySegment<byte>(new byte[10]), SocketFlags.Partial);
if (receivedBytes == 0) if (receivedBytes == 0)
{ {
return; return;


Loading…
Cancel
Save