Browse Source

Improve socket handling.

release/3.x.x
Christian Kratky 4 years ago
parent
commit
ad128c7889
11 changed files with 438 additions and 194 deletions
  1. +229
    -0
      Source/MQTTnet/Implementations/CrossPlatformSocket.cs
  2. +18
    -20
      Source/MQTTnet/Implementations/MqttTcpChannel.cs
  3. +7
    -7
      Source/MQTTnet/Implementations/MqttTcpServerListener.cs
  4. +1
    -86
      Source/MQTTnet/Implementations/PlatformAbstractionLayer.cs
  5. +14
    -14
      Source/MQTTnet/Server/MqttClientConnection.cs
  6. +5
    -6
      Source/MQTTnet/Server/MqttClientSessionsManager.cs
  7. +9
    -9
      Source/MQTTnet/Server/MqttClientSubscriptionsManager.cs
  8. +55
    -27
      Source/MQTTnet/Server/MqttServerEventDispatcher.cs
  9. +74
    -0
      Tests/MQTTnet.Core.Tests/CrossPlatformSocket_Tests.cs
  10. +9
    -9
      Tests/MQTTnet.Core.Tests/MqttTcpChannel_Tests.cs
  11. +17
    -16
      Tests/MQTTnet.Core.Tests/Server_Tests.cs

+ 229
- 0
Source/MQTTnet/Implementations/CrossPlatformSocket.cs View File

@@ -0,0 +1,229 @@
using System;
using System.Net;
using System.Net.Sockets;
using System.Threading;
using System.Threading.Tasks;

namespace MQTTnet.Implementations
{
public sealed class CrossPlatformSocket : IDisposable
{
readonly Socket _socket;

public CrossPlatformSocket(AddressFamily addressFamily)
{
_socket = new Socket(addressFamily, SocketType.Stream, ProtocolType.Tcp);
}

public CrossPlatformSocket()
{
// Having this contructor is important because avoiding the address family as parameter
// will make use of dual mode in the .net framework.
_socket = new Socket(SocketType.Stream, ProtocolType.Tcp);
}

public CrossPlatformSocket(Socket socket)
{
_socket = socket ?? throw new ArgumentNullException(nameof(socket));
}

public bool NoDelay
{
get
{
return (int)_socket.GetSocketOption(SocketOptionLevel.Tcp, SocketOptionName.NoDelay) > 0;
}

set
{
_socket.SetSocketOption(SocketOptionLevel.Tcp, SocketOptionName.NoDelay, value ? 1 : 0);
}
}

public bool DualMode
{
get
{
return _socket.DualMode;
}

set
{
_socket.DualMode = value;
}
}

public int ReceiveBufferSize
{
get
{
return _socket.ReceiveBufferSize;
}

set
{
_socket.ReceiveBufferSize = value;
}
}

public int SendBufferSize
{
get
{
return _socket.SendBufferSize;
}

set
{
_socket.SendBufferSize = value;
}
}

public EndPoint RemoteEndPoint
{
get
{
return _socket.RemoteEndPoint;
}
}

public bool ReuseAddress
{
get
{
return (int)_socket.GetSocketOption(SocketOptionLevel.Socket, SocketOptionName.ReuseAddress) != 0;
}

set
{
_socket.SetSocketOption(SocketOptionLevel.Socket, SocketOptionName.ReuseAddress, value ? 1 : 0);
}
}

public async Task<CrossPlatformSocket> AcceptAsync()
{
try
{
#if NET452 || NET461
var clientSocket = await Task.Factory.FromAsync(_socket.BeginAccept, _socket.EndAccept, null).ConfigureAwait(false);
return new CrossPlatformSocket(clientSocket);
#else
var clientSocket = await _socket.AcceptAsync().ConfigureAwait(false);
return new CrossPlatformSocket(clientSocket);
#endif
}
catch (ObjectDisposedException)
{
// This will happen when _socket.EndAccept gets called by Task library but the socket is already disposed.
return null;
}
}

public void Bind(EndPoint localEndPoint)
{
if (localEndPoint is null) throw new ArgumentNullException(nameof(localEndPoint));

_socket.Bind(localEndPoint);
}

public void Listen(int connectionBacklog)
{
_socket.Listen(connectionBacklog);
}

public async Task ConnectAsync(string host, int port, CancellationToken cancellationToken)
{
if (host is null) throw new ArgumentNullException(nameof(host));

try
{
// Workaround for: workaround for https://github.com/dotnet/corefx/issues/24430
using (cancellationToken.Register(() => _socket.Dispose()))
{
cancellationToken.ThrowIfCancellationRequested();

#if NET452 || NET461
await Task.Factory.FromAsync(_socket.BeginConnect, _socket.EndConnect, host, port, null).ConfigureAwait(false);
#else
await _socket.ConnectAsync(host, port).ConfigureAwait(false);
#endif
}
}
catch (ObjectDisposedException)
{
// This will happen when _socket.EndConnect gets called by Task library but the socket is already disposed.
}
}

public async Task SendAsync(ArraySegment<byte> buffer, SocketFlags socketFlags)
{
try
{
#if NET452 || NET461
await Task.Factory.FromAsync(SocketWrapper.BeginSend, _socket.EndSend, new SocketWrapper(_socket, buffer, socketFlags)).ConfigureAwait(false);
#else
await _socket.SendAsync(buffer, socketFlags).ConfigureAwait(false);
#endif
}
catch (ObjectDisposedException)
{
// This will happen when _socket.EndConnect gets called by Task library but the socket is already disposed.
}
}

public async Task<int> ReceiveAsync(ArraySegment<byte> buffer, SocketFlags socketFlags)
{
try
{
#if NET452 || NET461
return await Task.Factory.FromAsync(SocketWrapper.BeginReceive, _socket.EndReceive, new SocketWrapper(_socket, buffer, socketFlags)).ConfigureAwait(false);
#else
return await _socket.ReceiveAsync(buffer, socketFlags).ConfigureAwait(false);
#endif
}
catch (ObjectDisposedException)
{
// This will happen when _socket.EndReceive gets called by Task library but the socket is already disposed.
return -1;
}
}

public NetworkStream GetStream()
{
return new NetworkStream(_socket, true);
}

public void Dispose()
{
_socket?.Dispose();
}

#if NET452 || NET461
class SocketWrapper
{
readonly Socket _socket;
readonly ArraySegment<byte> _buffer;
readonly SocketFlags _socketFlags;

public SocketWrapper(Socket socket, ArraySegment<byte> buffer, SocketFlags socketFlags)
{
_socket = socket;
_buffer = buffer;
_socketFlags = socketFlags;
}

public static IAsyncResult BeginSend(AsyncCallback callback, object state)
{
var socketWrapper = (SocketWrapper)state;
return socketWrapper._socket.BeginSend(socketWrapper._buffer.Array, socketWrapper._buffer.Offset, socketWrapper._buffer.Count, socketWrapper._socketFlags, callback, state);
}

public static IAsyncResult BeginReceive(AsyncCallback callback, object state)
{
var socketWrapper = (SocketWrapper)state;
return socketWrapper._socket.BeginReceive(socketWrapper._buffer.Array, socketWrapper._buffer.Offset, socketWrapper._buffer.Count, socketWrapper._socketFlags, callback, state);
}
}
#endif
}
}

+ 18
- 20
Source/MQTTnet/Implementations/MqttTcpChannel.cs View File

@@ -46,15 +46,15 @@ namespace MQTTnet.Implementations

public async Task ConnectAsync(CancellationToken cancellationToken)
{
Socket socket;
CrossPlatformSocket socket;

if (_options.AddressFamily == AddressFamily.Unspecified)
{
socket = new Socket(SocketType.Stream, ProtocolType.Tcp);
socket = new CrossPlatformSocket();
}
else
{
socket = new Socket(_options.AddressFamily, SocketType.Stream, ProtocolType.Tcp);
socket = new CrossPlatformSocket(_options.AddressFamily);
}

socket.ReceiveBufferSize = _options.BufferSize;
@@ -69,20 +69,24 @@ namespace MQTTnet.Implementations
socket.DualMode = _options.DualMode.Value;
}

// Workaround for: workaround for https://github.com/dotnet/corefx/issues/24430
using (cancellationToken.Register(() => socket.Dispose()))
{
await PlatformAbstractionLayer.ConnectAsync(socket, _options.Server, _options.GetPort()).ConfigureAwait(false);
}
await socket.ConnectAsync(_options.Server, _options.GetPort(), cancellationToken).ConfigureAwait(false);

var networkStream = new NetworkStream(socket, true);
var networkStream = socket.GetStream();

if (_options.TlsOptions.UseTls)
{
var sslStream = new SslStream(networkStream, false, InternalUserCertificateValidationCallback);
_stream = sslStream;
try
{
await sslStream.AuthenticateAsClientAsync(_options.Server, LoadCertificates(), _options.TlsOptions.SslProtocol, !_options.TlsOptions.IgnoreCertificateRevocationErrors).ConfigureAwait(false);
}
catch
{
sslStream.Dispose();
throw;
}

await sslStream.AuthenticateAsClientAsync(_options.Server, LoadCertificates(), _options.TlsOptions.SslProtocol, !_options.TlsOptions.IgnoreCertificateRevocationErrors).ConfigureAwait(false);
_stream = sslStream;
}
else
{
@@ -107,17 +111,14 @@ namespace MQTTnet.Implementations
// Workaround for: https://github.com/dotnet/corefx/issues/24430
using (cancellationToken.Register(Dispose))
{
if (cancellationToken.IsCancellationRequested)
{
return 0;
}
cancellationToken.ThrowIfCancellationRequested();

return await _stream.ReadAsync(buffer, offset, count, cancellationToken).ConfigureAwait(false);
}
}
catch (ObjectDisposedException)
{
return 0;
return -1;
}
catch (IOException exception)
{
@@ -139,10 +140,7 @@ namespace MQTTnet.Implementations
// Workaround for: https://github.com/dotnet/corefx/issues/24430
using (cancellationToken.Register(Dispose))
{
if (cancellationToken.IsCancellationRequested)
{
return;
}
cancellationToken.ThrowIfCancellationRequested();

await _stream.WriteAsync(buffer, offset, count, cancellationToken).ConfigureAwait(false);
}


+ 7
- 7
Source/MQTTnet/Implementations/MqttTcpServerListener.cs View File

@@ -23,7 +23,7 @@ namespace MQTTnet.Implementations
readonly MqttServerTlsTcpEndpointOptions _tlsOptions;
readonly X509Certificate2 _tlsCertificate;

private Socket _socket;
private CrossPlatformSocket _socket;
private IPEndPoint _localEndPoint;

public MqttTcpServerListener(
@@ -59,18 +59,18 @@ namespace MQTTnet.Implementations

_logger.Info($"Starting TCP listener for {_localEndPoint} TLS={_tlsCertificate != null}.");

_socket = new Socket(_addressFamily, SocketType.Stream, ProtocolType.Tcp);
_socket = new CrossPlatformSocket(_addressFamily);

// Usage of socket options is described here: https://docs.microsoft.com/en-us/dotnet/api/system.net.sockets.socket.setsocketoption?view=netcore-2.2

if (_options.ReuseAddress)
{
_socket.SetSocketOption(SocketOptionLevel.Socket, SocketOptionName.ReuseAddress, true);
_socket.ReuseAddress = true;
}

if (_options.NoDelay)
{
_socket.SetSocketOption(SocketOptionLevel.Tcp, SocketOptionName.NoDelay, true);
_socket.NoDelay = true;
}

_socket.Bind(_localEndPoint);
@@ -107,7 +107,7 @@ namespace MQTTnet.Implementations
{
try
{
var clientSocket = await PlatformAbstractionLayer.AcceptAsync(_socket).ConfigureAwait(false);
var clientSocket = await _socket.AcceptAsync().ConfigureAwait(false);
if (clientSocket == null)
{
continue;
@@ -135,7 +135,7 @@ namespace MQTTnet.Implementations
}
}

async Task TryHandleClientConnectionAsync(Socket clientSocket)
async Task TryHandleClientConnectionAsync(CrossPlatformSocket clientSocket)
{
Stream stream = null;
string remoteEndPoint = null;
@@ -151,7 +151,7 @@ namespace MQTTnet.Implementations

clientSocket.NoDelay = _options.NoDelay;

stream = new NetworkStream(clientSocket, true);
stream = clientSocket.GetStream();

X509Certificate2 clientCertificate = null;



+ 1
- 86
Source/MQTTnet/Implementations/PlatformAbstractionLayer.cs View File

@@ -1,94 +1,9 @@
using System;
using System.Net;
using System.Net.Sockets;
using System.Threading.Tasks;
using System.Threading.Tasks;

namespace MQTTnet.Implementations
{
public static class PlatformAbstractionLayer
{
// TODO: Consider creating primitives like "MqttNetSocket" which will wrap all required methods and do the platform stuff.
public static async Task<Socket> AcceptAsync(Socket socket)
{
#if NET452 || NET461
try
{
return await Task.Factory.FromAsync(socket.BeginAccept, socket.EndAccept, null).ConfigureAwait(false);
}
catch (ObjectDisposedException)
{
return null;
}
#else
return await socket.AcceptAsync().ConfigureAwait(false);
#endif
}


public static Task ConnectAsync(Socket socket, IPAddress ip, int port)
{
#if NET452 || NET461
return Task.Factory.FromAsync(socket.BeginConnect, socket.EndConnect, ip, port, null);
#else
return socket.ConnectAsync(ip, port);
#endif
}

public static Task ConnectAsync(Socket socket, string host, int port)
{
#if NET452 || NET461
return Task.Factory.FromAsync(socket.BeginConnect, socket.EndConnect, host, port, null);
#else
return socket.ConnectAsync(host, port);
#endif
}

#if NET452 || NET461
public class SocketWrapper
{
private readonly Socket _socket;
private readonly ArraySegment<byte> _buffer;
private readonly SocketFlags _socketFlags;

public SocketWrapper(Socket socket, ArraySegment<byte> buffer, SocketFlags socketFlags)
{
_socket = socket;
_buffer = buffer;
_socketFlags = socketFlags;
}

public static IAsyncResult BeginSend(AsyncCallback callback, object state)
{
var real = (SocketWrapper)state;
return real._socket.BeginSend(real._buffer.Array, real._buffer.Offset, real._buffer.Count, real._socketFlags, callback, state);
}

public static IAsyncResult BeginReceive(AsyncCallback callback, object state)
{
var real = (SocketWrapper)state;
return real._socket.BeginReceive(real._buffer.Array, real._buffer.Offset, real._buffer.Count, real._socketFlags, callback, state);
}
}
#endif

public static Task SendAsync(Socket socket, ArraySegment<byte> buffer, SocketFlags socketFlags)
{
#if NET452 || NET461
return Task.Factory.FromAsync(SocketWrapper.BeginSend, socket.EndSend, new SocketWrapper(socket, buffer, socketFlags));
#else
return socket.SendAsync(buffer, socketFlags);
#endif
}

public static Task<int> ReceiveAsync(Socket socket, ArraySegment<byte> buffer, SocketFlags socketFlags)
{
#if NET452 || NET461
return Task.Factory.FromAsync(SocketWrapper.BeginReceive, socket.EndReceive, new SocketWrapper(socket, buffer, socketFlags));
#else
return socket.ReceiveAsync(buffer, socketFlags);
#endif
}

public static Task CompletedTask
{
get


+ 14
- 14
Source/MQTTnet/Server/MqttClientConnection.cs View File

@@ -15,7 +15,7 @@ using System.Threading.Tasks;

namespace MQTTnet.Server
{
public class MqttClientConnection : IDisposable
public sealed class MqttClientConnection : IDisposable
{
private readonly MqttPacketIdentifierProvider _packetIdentifierProvider = new MqttPacketIdentifierProvider();
private readonly MqttPacketDispatcher _packetDispatcher = new MqttPacketDispatcher();
@@ -124,7 +124,7 @@ namespace MQTTnet.Server
return _packageReceiverTask;
}

private async Task<MqttClientDisconnectType> RunInternalAsync(MqttConnectionValidatorContext connectionValidatorContext)
async Task<MqttClientDisconnectType> RunInternalAsync(MqttConnectionValidatorContext connectionValidatorContext)
{
var disconnectType = MqttClientDisconnectType.NotClean;
try
@@ -251,12 +251,12 @@ namespace MQTTnet.Server
return disconnectType;
}

private void StopInternal()
void StopInternal()
{
_cancellationToken.Cancel(false);
}

private async Task EnqueueSubscribedRetainedMessagesAsync(ICollection<TopicFilter> topicFilters)
async Task EnqueueSubscribedRetainedMessagesAsync(ICollection<TopicFilter> topicFilters)
{
var retainedMessages = await _retainedMessagesManager.GetSubscribedMessagesAsync(topicFilters).ConfigureAwait(false);
foreach (var applicationMessage in retainedMessages)
@@ -265,7 +265,7 @@ namespace MQTTnet.Server
}
}

private async Task HandleIncomingSubscribePacketAsync(MqttSubscribePacket subscribePacket)
async Task HandleIncomingSubscribePacketAsync(MqttSubscribePacket subscribePacket)
{
// TODO: Let the channel adapter create the packet.
var subscribeResult = await Session.SubscriptionsManager.SubscribeAsync(subscribePacket, ConnectPacket).ConfigureAwait(false);
@@ -281,14 +281,14 @@ namespace MQTTnet.Server
await EnqueueSubscribedRetainedMessagesAsync(subscribePacket.TopicFilters).ConfigureAwait(false);
}

private async Task HandleIncomingUnsubscribePacketAsync(MqttUnsubscribePacket unsubscribePacket)
async Task HandleIncomingUnsubscribePacketAsync(MqttUnsubscribePacket unsubscribePacket)
{
// TODO: Let the channel adapter create the packet.
var unsubscribeResult = await Session.SubscriptionsManager.UnsubscribeAsync(unsubscribePacket).ConfigureAwait(false);
await SendAsync(unsubscribeResult).ConfigureAwait(false);
}

private Task HandleIncomingPublishPacketAsync(MqttPublishPacket publishPacket)
Task HandleIncomingPublishPacketAsync(MqttPublishPacket publishPacket)
{
Interlocked.Increment(ref _sentApplicationMessagesCount);

@@ -313,7 +313,7 @@ namespace MQTTnet.Server
}
}

private Task HandleIncomingPublishPacketWithQoS0Async(MqttPublishPacket publishPacket)
Task HandleIncomingPublishPacketWithQoS0Async(MqttPublishPacket publishPacket)
{
var applicationMessage = _dataConverter.CreateApplicationMessage(publishPacket);

@@ -322,7 +322,7 @@ namespace MQTTnet.Server
return Task.FromResult(0);
}

private Task HandleIncomingPublishPacketWithQoS1Async(MqttPublishPacket publishPacket)
Task HandleIncomingPublishPacketWithQoS1Async(MqttPublishPacket publishPacket)
{
var applicationMessage = _dataConverter.CreateApplicationMessage(publishPacket);
_sessionsManager.DispatchApplicationMessage(applicationMessage, this);
@@ -331,7 +331,7 @@ namespace MQTTnet.Server
return SendAsync(pubAckPacket);
}

private Task HandleIncomingPublishPacketWithQoS2Async(MqttPublishPacket publishPacket)
Task HandleIncomingPublishPacketWithQoS2Async(MqttPublishPacket publishPacket)
{
var applicationMessage = _dataConverter.CreateApplicationMessage(publishPacket);
_sessionsManager.DispatchApplicationMessage(applicationMessage, this);
@@ -345,7 +345,7 @@ namespace MQTTnet.Server
return SendAsync(pubRecPacket);
}

private async Task SendPendingPacketsAsync(CancellationToken cancellationToken)
async Task SendPendingPacketsAsync(CancellationToken cancellationToken)
{
MqttQueuedApplicationMessage queuedApplicationMessage = null;
MqttPublishPacket publishPacket = null;
@@ -459,7 +459,7 @@ namespace MQTTnet.Server
}
}

private async Task SendAsync(MqttBasePacket packet)
async Task SendAsync(MqttBasePacket packet)
{
await _channelAdapter.SendPacketAsync(packet, _serverOptions.DefaultCommunicationTimeout, _cancellationToken.Token).ConfigureAwait(false);

@@ -471,12 +471,12 @@ namespace MQTTnet.Server
}
}

private void OnAdapterReadingPacketCompleted()
void OnAdapterReadingPacketCompleted()
{
_keepAliveMonitor?.Resume();
}

private void OnAdapterReadingPacketStarted()
void OnAdapterReadingPacketStarted()
{
_keepAliveMonitor?.Pause();
}


+ 5
- 6
Source/MQTTnet/Server/MqttClientSessionsManager.cs View File

@@ -207,7 +207,7 @@ namespace MQTTnet.Server
applicationMessage = interceptorContext.ApplicationMessage;
}

await _eventDispatcher.HandleApplicationMessageReceivedAsync(sender?.ClientId, applicationMessage).ConfigureAwait(false);
await _eventDispatcher.SafeNotifyApplicationMessageReceivedAsync(sender?.ClientId, applicationMessage).ConfigureAwait(false);

if (applicationMessage.Retain)
{
@@ -237,7 +237,7 @@ namespace MQTTnet.Server
string clientId = null;
var clientWasConnected = true;

MqttConnectPacket connectPacket = null;
MqttConnectPacket connectPacket;

try
{
@@ -259,8 +259,6 @@ namespace MQTTnet.Server

var connectionValidatorContext = await ValidateConnectionAsync(connectPacket, channelAdapter).ConfigureAwait(false);

clientId = connectPacket.ClientId;

if (connectionValidatorContext.ReasonCode != MqttConnectReasonCode.Success)
{
clientWasConnected = false;
@@ -272,9 +270,10 @@ namespace MQTTnet.Server
return;
}

clientId = connectPacket.ClientId;
var connection = await CreateConnectionAsync(connectPacket, connectionValidatorContext, channelAdapter).ConfigureAwait(false);

await _eventDispatcher.HandleClientConnectedAsync(clientId).ConfigureAwait(false);
await _eventDispatcher.SafeNotifyClientConnectedAsync(clientId).ConfigureAwait(false);

disconnectType = await connection.RunAsync(connectionValidatorContext).ConfigureAwait(false);
}
@@ -303,7 +302,7 @@ namespace MQTTnet.Server

if (clientId != null)
{
await _eventDispatcher.TryHandleClientDisconnectedAsync(clientId, disconnectType).ConfigureAwait(false);
await _eventDispatcher.SafeNotifyClientDisconnectedAsync(clientId, disconnectType).ConfigureAwait(false);
}
}
}


+ 9
- 9
Source/MQTTnet/Server/MqttClientSubscriptionsManager.cs View File

@@ -1,9 +1,9 @@
using System;
using MQTTnet.Packets;
using MQTTnet.Protocol;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
using MQTTnet.Packets;
using MQTTnet.Protocol;

namespace MQTTnet.Server
{
@@ -67,7 +67,7 @@ namespace MQTTnet.Server
_subscriptions[finalTopicFilter.Topic] = finalTopicFilter;
}

await _eventDispatcher.HandleClientSubscribedTopicAsync(_clientSession.ClientId, finalTopicFilter).ConfigureAwait(false);
await _eventDispatcher.SafeNotifyClientSubscribedTopicAsync(_clientSession.ClientId, finalTopicFilter).ConfigureAwait(false);
}
}

@@ -83,7 +83,7 @@ namespace MQTTnet.Server
var interceptorContext = await InterceptSubscribeAsync(topicFilter).ConfigureAwait(false);
if (!interceptorContext.AcceptSubscription)
{
continue;
continue;
}

if (interceptorContext.AcceptSubscription)
@@ -93,7 +93,7 @@ namespace MQTTnet.Server
_subscriptions[topicFilter.Topic] = topicFilter;
}

await _eventDispatcher.HandleClientSubscribedTopicAsync(_clientSession.ClientId, topicFilter).ConfigureAwait(false);
await _eventDispatcher.SafeNotifyClientSubscribedTopicAsync(_clientSession.ClientId, topicFilter).ConfigureAwait(false);
}
}
}
@@ -131,9 +131,9 @@ namespace MQTTnet.Server

foreach (var topicFilter in unsubscribePacket.TopicFilters)
{
await _eventDispatcher.HandleClientUnsubscribedTopicAsync(_clientSession.ClientId, topicFilter).ConfigureAwait(false);
await _eventDispatcher.SafeNotifyClientUnsubscribedTopicAsync(_clientSession.ClientId, topicFilter).ConfigureAwait(false);
}
return unsubAckPacket;
}

@@ -152,7 +152,7 @@ namespace MQTTnet.Server
lock (_subscriptions)
{
_subscriptions.Remove(topicFilter);
}
}
}
}



+ 55
- 27
Source/MQTTnet/Server/MqttServerEventDispatcher.cs View File

@@ -7,7 +7,7 @@ namespace MQTTnet.Server
{
public class MqttServerEventDispatcher
{
private readonly IMqttNetLogger _logger;
readonly IMqttNetLogger _logger;

public MqttServerEventDispatcher(IMqttNetLogger logger)
{
@@ -24,18 +24,25 @@ namespace MQTTnet.Server

public IMqttApplicationMessageReceivedHandler ApplicationMessageReceivedHandler { get; set; }

public Task HandleClientConnectedAsync(string clientId)
public async Task SafeNotifyClientConnectedAsync(string clientId)
{
var handler = ClientConnectedHandler;
if (handler == null)
try
{
return Task.FromResult(0);
}
var handler = ClientConnectedHandler;
if (handler == null)
{
return;
}

return handler.HandleClientConnectedAsync(new MqttServerClientConnectedEventArgs(clientId));
await handler.HandleClientConnectedAsync(new MqttServerClientConnectedEventArgs(clientId)).ConfigureAwait(false);
}
catch (Exception exception)
{
_logger.Error(exception, "Error while handling custom 'ClientConnected' event.");
}
}

public async Task TryHandleClientDisconnectedAsync(string clientId, MqttClientDisconnectType disconnectType)
public async Task SafeNotifyClientDisconnectedAsync(string clientId, MqttClientDisconnectType disconnectType)
{
try
{
@@ -49,41 +56,62 @@ namespace MQTTnet.Server
}
catch (Exception exception)
{
_logger.Error(exception, "Error while handling 'ClientDisconnected' event.");
_logger.Error(exception, "Error while handling custom 'ClientDisconnected' event.");
}
}

public Task HandleClientSubscribedTopicAsync(string clientId, TopicFilter topicFilter)
public async Task SafeNotifyClientSubscribedTopicAsync(string clientId, TopicFilter topicFilter)
{
var handler = ClientSubscribedTopicHandler;
if (handler == null)
try
{
return Task.FromResult(0);
}
var handler = ClientSubscribedTopicHandler;
if (handler == null)
{
return;
}

return handler.HandleClientSubscribedTopicAsync(new MqttServerClientSubscribedTopicEventArgs(clientId, topicFilter));
await handler.HandleClientSubscribedTopicAsync(new MqttServerClientSubscribedTopicEventArgs(clientId, topicFilter)).ConfigureAwait(false);
}
catch (Exception exception)
{
_logger.Error(exception, "Error while handling custom 'ClientSubscribedTopic' event.");
}
}

public Task HandleClientUnsubscribedTopicAsync(string clientId, string topicFilter)
public async Task SafeNotifyClientUnsubscribedTopicAsync(string clientId, string topicFilter)
{
var handler = ClientUnsubscribedTopicHandler;
if (handler == null)
try
{
return Task.FromResult(0);
}
var handler = ClientUnsubscribedTopicHandler;
if (handler == null)
{
return;
}

return handler.HandleClientUnsubscribedTopicAsync(new MqttServerClientUnsubscribedTopicEventArgs(clientId, topicFilter));
await handler.HandleClientUnsubscribedTopicAsync(new MqttServerClientUnsubscribedTopicEventArgs(clientId, topicFilter)).ConfigureAwait(false);
}
catch (Exception exception)
{
_logger.Error(exception, "Error while handling custom 'ClientUnsubscribedTopic' event.");
}
}

public Task HandleApplicationMessageReceivedAsync(string senderClientId, MqttApplicationMessage applicationMessage)
public async Task SafeNotifyApplicationMessageReceivedAsync(string senderClientId, MqttApplicationMessage applicationMessage)
{
var handler = ApplicationMessageReceivedHandler;
if (handler == null)
try
{
return Task.FromResult(0);
}
var handler = ApplicationMessageReceivedHandler;
if (handler == null)
{
return;
}

return handler.HandleApplicationMessageReceivedAsync(new MqttApplicationMessageReceivedEventArgs(senderClientId, applicationMessage));
await handler.HandleApplicationMessageReceivedAsync(new MqttApplicationMessageReceivedEventArgs(senderClientId, applicationMessage)).ConfigureAwait(false); ;
}
catch (Exception exception)
{
_logger.Error(exception, "Error while handling custom 'ApplicationMessageReceived' event.");
}
}
}
}

+ 74
- 0
Tests/MQTTnet.Core.Tests/CrossPlatformSocket_Tests.cs View File

@@ -0,0 +1,74 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using MQTTnet.Implementations;
using System;
using System.Text;
using System.Threading;
using System.Threading.Tasks;

namespace MQTTnet.Tests
{
[TestClass]
public class CrossPlatformSocket_Tests
{
[TestMethod]
public async Task Connect_Send_Receive()
{
var crossPlatformSocket = new CrossPlatformSocket();
await crossPlatformSocket.ConnectAsync("www.google.de", 80, CancellationToken.None);

var requestBuffer = Encoding.UTF8.GetBytes("GET / HTTP/1.1\r\nHost: www.google.de\r\n\r\n");
await crossPlatformSocket.SendAsync(new ArraySegment<byte>(requestBuffer), System.Net.Sockets.SocketFlags.None);

var buffer = new byte[1024];
var length = await crossPlatformSocket.ReceiveAsync(new ArraySegment<byte>(buffer), System.Net.Sockets.SocketFlags.None);
crossPlatformSocket.Dispose();

var responseText = Encoding.UTF8.GetString(buffer, 0, length);

Assert.IsTrue(responseText.Contains("HTTP/1.1 200 OK"));
}

[TestMethod]
public async Task Try_Connect_Invalid_Host()
{
var crossPlatformSocket = new CrossPlatformSocket();

var cancellationToken = new CancellationTokenSource(TimeSpan.FromSeconds(3));
cancellationToken.Token.Register(() => crossPlatformSocket.Dispose());

await crossPlatformSocket.ConnectAsync("www.google.de", 1234, CancellationToken.None);
}

//[TestMethod]
//public async Task Use_Disconnected_Socket()
//{
// var crossPlatformSocket = new CrossPlatformSocket();

// await crossPlatformSocket.ConnectAsync("www.google.de", 80);

// var requestBuffer = Encoding.UTF8.GetBytes("GET /wrong_uri HTTP/1.1\r\nConnection: close\r\n\r\n");
// await crossPlatformSocket.SendAsync(new ArraySegment<byte>(requestBuffer), System.Net.Sockets.SocketFlags.None);

// var buffer = new byte[64000];
// var length = await crossPlatformSocket.ReceiveAsync(new ArraySegment<byte>(buffer), System.Net.Sockets.SocketFlags.None);

// await Task.Delay(500);

// await crossPlatformSocket.SendAsync(new ArraySegment<byte>(requestBuffer), System.Net.Sockets.SocketFlags.None);
//}

[TestMethod]
public async Task Set_Options()
{
var crossPlatformSocket = new CrossPlatformSocket();

Assert.IsFalse(crossPlatformSocket.ReuseAddress);
crossPlatformSocket.ReuseAddress = true;
Assert.IsTrue(crossPlatformSocket.ReuseAddress);

Assert.IsFalse(crossPlatformSocket.NoDelay);
crossPlatformSocket.NoDelay = true;
Assert.IsTrue(crossPlatformSocket.NoDelay);
}
}
}

+ 9
- 9
Tests/MQTTnet.Core.Tests/MqttTcpChannel_Tests.cs View File

@@ -1,10 +1,10 @@
using System;
using Microsoft.VisualStudio.TestTools.UnitTesting;
using MQTTnet.Implementations;
using System;
using System.Net;
using System.Net.Sockets;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.VisualStudio.TestTools.UnitTesting;
using MQTTnet.Implementations;

namespace MQTTnet.Tests
{
@@ -15,7 +15,7 @@ namespace MQTTnet.Tests
public async Task Dispose_Channel_While_Used()
{
var ct = new CancellationTokenSource();
var serverSocket = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
var serverSocket = new CrossPlatformSocket(AddressFamily.InterNetwork);

try
{
@@ -28,18 +28,18 @@ namespace MQTTnet.Tests
{
while (!ct.IsCancellationRequested)
{
var client = await PlatformAbstractionLayer.AcceptAsync(serverSocket);
var client = await serverSocket.AcceptAsync();
var data = new byte[] { 128 };
await PlatformAbstractionLayer.SendAsync(client, new ArraySegment<byte>(data), SocketFlags.None);
await client.SendAsync(new ArraySegment<byte>(data), SocketFlags.None);
}
}, ct.Token);

var clientSocket = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
await PlatformAbstractionLayer.ConnectAsync(clientSocket, IPAddress.Loopback, 50001);
var clientSocket = new CrossPlatformSocket(AddressFamily.InterNetwork);
await clientSocket.ConnectAsync("localhost", 50001, CancellationToken.None);

await Task.Delay(100, ct.Token);

var tcpChannel = new MqttTcpChannel(new NetworkStream(clientSocket, true), "test", null);
var tcpChannel = new MqttTcpChannel(clientSocket.GetStream(), "test", null);

var buffer = new byte[1];
await tcpChannel.ReadAsync(buffer, 0, 1, ct.Token);


+ 17
- 16
Tests/MQTTnet.Core.Tests/Server_Tests.cs View File

@@ -904,14 +904,13 @@ namespace MQTTnet.Tests

await testEnvironment.StartServerAsync(serverOptions);


var connectingFailedException = await Assert.ThrowsExceptionAsync<MqttConnectingFailedException>(() => testEnvironment.ConnectClientAsync());
Assert.AreEqual(MqttClientConnectResultCode.NotAuthorized, connectingFailedException.ResultCode);
}
}

Dictionary<string, bool> _connected;

private Dictionary<string, bool> _connected;
private void ConnectionValidationHandler(MqttConnectionValidatorContext eventArgs)
{
if (_connected.ContainsKey(eventArgs.ClientId))
@@ -919,6 +918,7 @@ namespace MQTTnet.Tests
eventArgs.ReasonCode = MqttConnectReasonCode.BadUserNameOrPassword;
return;
}

_connected[eventArgs.ClientId] = true;
eventArgs.ReasonCode = MqttConnectReasonCode.Success;
return;
@@ -1053,6 +1053,12 @@ namespace MQTTnet.Tests
// Connect client with same client ID. Should disconnect existing client.
var c2 = await testEnvironment.ConnectClientAsync(clientOptionsBuilder);

await Task.Delay(500);

flow = string.Join(string.Empty, events);

Assert.AreEqual("cdc", flow);

c2.UseApplicationMessageReceivedHandler(_ =>
{
lock (events)
@@ -1061,15 +1067,10 @@ namespace MQTTnet.Tests
}
});

c2.SubscribeAsync("topic").Wait();

await Task.Delay(500);

flow = string.Join(string.Empty, events);
Assert.AreEqual("cdc", flow);
await c2.SubscribeAsync("topic");

// r
c2.PublishAsync("topic").Wait();
await c2.PublishAsync("topic");

await Task.Delay(500);

@@ -1149,15 +1150,15 @@ namespace MQTTnet.Tests
{
await testEnvironment.StartServerAsync(new MqttServerOptionsBuilder().WithDefaultCommunicationTimeout(TimeSpan.FromSeconds(1)));

var client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
await PlatformAbstractionLayer.ConnectAsync(client, "localhost", testEnvironment.ServerPort);
var client = new CrossPlatformSocket(AddressFamily.InterNetwork);
await client.ConnectAsync("localhost", testEnvironment.ServerPort, CancellationToken.None);

// Don't send anything. The server should close the connection.
await Task.Delay(TimeSpan.FromSeconds(3));

try
{
var receivedBytes = await PlatformAbstractionLayer.ReceiveAsync(client, new ArraySegment<byte>(new byte[10]), SocketFlags.Partial);
var receivedBytes = await client.ReceiveAsync(new ArraySegment<byte>(new byte[10]), SocketFlags.Partial);
if (receivedBytes == 0)
{
return;
@@ -1180,17 +1181,17 @@ namespace MQTTnet.Tests

// Send an invalid packet and ensure that the server will close the connection and stay in a waiting state
// forever. This is security related.
var client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
await PlatformAbstractionLayer.ConnectAsync(client, "localhost", testEnvironment.ServerPort);
var client = new CrossPlatformSocket(AddressFamily.InterNetwork);
await client.ConnectAsync("localhost", testEnvironment.ServerPort, CancellationToken.None);

var buffer = Encoding.UTF8.GetBytes("Garbage");
client.Send(buffer, buffer.Length, SocketFlags.None);
await client.SendAsync(new ArraySegment<byte>(buffer), SocketFlags.None);

await Task.Delay(TimeSpan.FromSeconds(3));

try
{
var receivedBytes = await PlatformAbstractionLayer.ReceiveAsync(client, new ArraySegment<byte>(new byte[10]), SocketFlags.Partial);
var receivedBytes = await client.ReceiveAsync(new ArraySegment<byte>(new byte[10]), SocketFlags.Partial);
if (receivedBytes == 0)
{
return;


Loading…
Cancel
Save