@@ -0,0 +1,229 @@ | |||
using System; | |||
using System.Net; | |||
using System.Net.Sockets; | |||
using System.Threading; | |||
using System.Threading.Tasks; | |||
namespace MQTTnet.Implementations | |||
{ | |||
public sealed class CrossPlatformSocket : IDisposable | |||
{ | |||
readonly Socket _socket; | |||
public CrossPlatformSocket(AddressFamily addressFamily) | |||
{ | |||
_socket = new Socket(addressFamily, SocketType.Stream, ProtocolType.Tcp); | |||
} | |||
public CrossPlatformSocket() | |||
{ | |||
// Having this contructor is important because avoiding the address family as parameter | |||
// will make use of dual mode in the .net framework. | |||
_socket = new Socket(SocketType.Stream, ProtocolType.Tcp); | |||
} | |||
public CrossPlatformSocket(Socket socket) | |||
{ | |||
_socket = socket ?? throw new ArgumentNullException(nameof(socket)); | |||
} | |||
public bool NoDelay | |||
{ | |||
get | |||
{ | |||
return (int)_socket.GetSocketOption(SocketOptionLevel.Tcp, SocketOptionName.NoDelay) > 0; | |||
} | |||
set | |||
{ | |||
_socket.SetSocketOption(SocketOptionLevel.Tcp, SocketOptionName.NoDelay, value ? 1 : 0); | |||
} | |||
} | |||
public bool DualMode | |||
{ | |||
get | |||
{ | |||
return _socket.DualMode; | |||
} | |||
set | |||
{ | |||
_socket.DualMode = value; | |||
} | |||
} | |||
public int ReceiveBufferSize | |||
{ | |||
get | |||
{ | |||
return _socket.ReceiveBufferSize; | |||
} | |||
set | |||
{ | |||
_socket.ReceiveBufferSize = value; | |||
} | |||
} | |||
public int SendBufferSize | |||
{ | |||
get | |||
{ | |||
return _socket.SendBufferSize; | |||
} | |||
set | |||
{ | |||
_socket.SendBufferSize = value; | |||
} | |||
} | |||
public EndPoint RemoteEndPoint | |||
{ | |||
get | |||
{ | |||
return _socket.RemoteEndPoint; | |||
} | |||
} | |||
public bool ReuseAddress | |||
{ | |||
get | |||
{ | |||
return (int)_socket.GetSocketOption(SocketOptionLevel.Socket, SocketOptionName.ReuseAddress) != 0; | |||
} | |||
set | |||
{ | |||
_socket.SetSocketOption(SocketOptionLevel.Socket, SocketOptionName.ReuseAddress, value ? 1 : 0); | |||
} | |||
} | |||
public async Task<CrossPlatformSocket> AcceptAsync() | |||
{ | |||
try | |||
{ | |||
#if NET452 || NET461 | |||
var clientSocket = await Task.Factory.FromAsync(_socket.BeginAccept, _socket.EndAccept, null).ConfigureAwait(false); | |||
return new CrossPlatformSocket(clientSocket); | |||
#else | |||
var clientSocket = await _socket.AcceptAsync().ConfigureAwait(false); | |||
return new CrossPlatformSocket(clientSocket); | |||
#endif | |||
} | |||
catch (ObjectDisposedException) | |||
{ | |||
// This will happen when _socket.EndAccept gets called by Task library but the socket is already disposed. | |||
return null; | |||
} | |||
} | |||
public void Bind(EndPoint localEndPoint) | |||
{ | |||
if (localEndPoint is null) throw new ArgumentNullException(nameof(localEndPoint)); | |||
_socket.Bind(localEndPoint); | |||
} | |||
public void Listen(int connectionBacklog) | |||
{ | |||
_socket.Listen(connectionBacklog); | |||
} | |||
public async Task ConnectAsync(string host, int port, CancellationToken cancellationToken) | |||
{ | |||
if (host is null) throw new ArgumentNullException(nameof(host)); | |||
try | |||
{ | |||
// Workaround for: workaround for https://github.com/dotnet/corefx/issues/24430 | |||
using (cancellationToken.Register(() => _socket.Dispose())) | |||
{ | |||
cancellationToken.ThrowIfCancellationRequested(); | |||
#if NET452 || NET461 | |||
await Task.Factory.FromAsync(_socket.BeginConnect, _socket.EndConnect, host, port, null).ConfigureAwait(false); | |||
#else | |||
await _socket.ConnectAsync(host, port).ConfigureAwait(false); | |||
#endif | |||
} | |||
} | |||
catch (ObjectDisposedException) | |||
{ | |||
// This will happen when _socket.EndConnect gets called by Task library but the socket is already disposed. | |||
} | |||
} | |||
public async Task SendAsync(ArraySegment<byte> buffer, SocketFlags socketFlags) | |||
{ | |||
try | |||
{ | |||
#if NET452 || NET461 | |||
await Task.Factory.FromAsync(SocketWrapper.BeginSend, _socket.EndSend, new SocketWrapper(_socket, buffer, socketFlags)).ConfigureAwait(false); | |||
#else | |||
await _socket.SendAsync(buffer, socketFlags).ConfigureAwait(false); | |||
#endif | |||
} | |||
catch (ObjectDisposedException) | |||
{ | |||
// This will happen when _socket.EndConnect gets called by Task library but the socket is already disposed. | |||
} | |||
} | |||
public async Task<int> ReceiveAsync(ArraySegment<byte> buffer, SocketFlags socketFlags) | |||
{ | |||
try | |||
{ | |||
#if NET452 || NET461 | |||
return await Task.Factory.FromAsync(SocketWrapper.BeginReceive, _socket.EndReceive, new SocketWrapper(_socket, buffer, socketFlags)).ConfigureAwait(false); | |||
#else | |||
return await _socket.ReceiveAsync(buffer, socketFlags).ConfigureAwait(false); | |||
#endif | |||
} | |||
catch (ObjectDisposedException) | |||
{ | |||
// This will happen when _socket.EndReceive gets called by Task library but the socket is already disposed. | |||
return -1; | |||
} | |||
} | |||
public NetworkStream GetStream() | |||
{ | |||
return new NetworkStream(_socket, true); | |||
} | |||
public void Dispose() | |||
{ | |||
_socket?.Dispose(); | |||
} | |||
#if NET452 || NET461 | |||
class SocketWrapper | |||
{ | |||
readonly Socket _socket; | |||
readonly ArraySegment<byte> _buffer; | |||
readonly SocketFlags _socketFlags; | |||
public SocketWrapper(Socket socket, ArraySegment<byte> buffer, SocketFlags socketFlags) | |||
{ | |||
_socket = socket; | |||
_buffer = buffer; | |||
_socketFlags = socketFlags; | |||
} | |||
public static IAsyncResult BeginSend(AsyncCallback callback, object state) | |||
{ | |||
var socketWrapper = (SocketWrapper)state; | |||
return socketWrapper._socket.BeginSend(socketWrapper._buffer.Array, socketWrapper._buffer.Offset, socketWrapper._buffer.Count, socketWrapper._socketFlags, callback, state); | |||
} | |||
public static IAsyncResult BeginReceive(AsyncCallback callback, object state) | |||
{ | |||
var socketWrapper = (SocketWrapper)state; | |||
return socketWrapper._socket.BeginReceive(socketWrapper._buffer.Array, socketWrapper._buffer.Offset, socketWrapper._buffer.Count, socketWrapper._socketFlags, callback, state); | |||
} | |||
} | |||
#endif | |||
} | |||
} |
@@ -46,15 +46,15 @@ namespace MQTTnet.Implementations | |||
public async Task ConnectAsync(CancellationToken cancellationToken) | |||
{ | |||
Socket socket; | |||
CrossPlatformSocket socket; | |||
if (_options.AddressFamily == AddressFamily.Unspecified) | |||
{ | |||
socket = new Socket(SocketType.Stream, ProtocolType.Tcp); | |||
socket = new CrossPlatformSocket(); | |||
} | |||
else | |||
{ | |||
socket = new Socket(_options.AddressFamily, SocketType.Stream, ProtocolType.Tcp); | |||
socket = new CrossPlatformSocket(_options.AddressFamily); | |||
} | |||
socket.ReceiveBufferSize = _options.BufferSize; | |||
@@ -69,20 +69,24 @@ namespace MQTTnet.Implementations | |||
socket.DualMode = _options.DualMode.Value; | |||
} | |||
// Workaround for: workaround for https://github.com/dotnet/corefx/issues/24430 | |||
using (cancellationToken.Register(() => socket.Dispose())) | |||
{ | |||
await PlatformAbstractionLayer.ConnectAsync(socket, _options.Server, _options.GetPort()).ConfigureAwait(false); | |||
} | |||
await socket.ConnectAsync(_options.Server, _options.GetPort(), cancellationToken).ConfigureAwait(false); | |||
var networkStream = new NetworkStream(socket, true); | |||
var networkStream = socket.GetStream(); | |||
if (_options.TlsOptions.UseTls) | |||
{ | |||
var sslStream = new SslStream(networkStream, false, InternalUserCertificateValidationCallback); | |||
_stream = sslStream; | |||
try | |||
{ | |||
await sslStream.AuthenticateAsClientAsync(_options.Server, LoadCertificates(), _options.TlsOptions.SslProtocol, !_options.TlsOptions.IgnoreCertificateRevocationErrors).ConfigureAwait(false); | |||
} | |||
catch | |||
{ | |||
sslStream.Dispose(); | |||
throw; | |||
} | |||
await sslStream.AuthenticateAsClientAsync(_options.Server, LoadCertificates(), _options.TlsOptions.SslProtocol, !_options.TlsOptions.IgnoreCertificateRevocationErrors).ConfigureAwait(false); | |||
_stream = sslStream; | |||
} | |||
else | |||
{ | |||
@@ -107,17 +111,14 @@ namespace MQTTnet.Implementations | |||
// Workaround for: https://github.com/dotnet/corefx/issues/24430 | |||
using (cancellationToken.Register(Dispose)) | |||
{ | |||
if (cancellationToken.IsCancellationRequested) | |||
{ | |||
return 0; | |||
} | |||
cancellationToken.ThrowIfCancellationRequested(); | |||
return await _stream.ReadAsync(buffer, offset, count, cancellationToken).ConfigureAwait(false); | |||
} | |||
} | |||
catch (ObjectDisposedException) | |||
{ | |||
return 0; | |||
return -1; | |||
} | |||
catch (IOException exception) | |||
{ | |||
@@ -139,10 +140,7 @@ namespace MQTTnet.Implementations | |||
// Workaround for: https://github.com/dotnet/corefx/issues/24430 | |||
using (cancellationToken.Register(Dispose)) | |||
{ | |||
if (cancellationToken.IsCancellationRequested) | |||
{ | |||
return; | |||
} | |||
cancellationToken.ThrowIfCancellationRequested(); | |||
await _stream.WriteAsync(buffer, offset, count, cancellationToken).ConfigureAwait(false); | |||
} | |||
@@ -23,7 +23,7 @@ namespace MQTTnet.Implementations | |||
readonly MqttServerTlsTcpEndpointOptions _tlsOptions; | |||
readonly X509Certificate2 _tlsCertificate; | |||
private Socket _socket; | |||
private CrossPlatformSocket _socket; | |||
private IPEndPoint _localEndPoint; | |||
public MqttTcpServerListener( | |||
@@ -59,18 +59,18 @@ namespace MQTTnet.Implementations | |||
_logger.Info($"Starting TCP listener for {_localEndPoint} TLS={_tlsCertificate != null}."); | |||
_socket = new Socket(_addressFamily, SocketType.Stream, ProtocolType.Tcp); | |||
_socket = new CrossPlatformSocket(_addressFamily); | |||
// Usage of socket options is described here: https://docs.microsoft.com/en-us/dotnet/api/system.net.sockets.socket.setsocketoption?view=netcore-2.2 | |||
if (_options.ReuseAddress) | |||
{ | |||
_socket.SetSocketOption(SocketOptionLevel.Socket, SocketOptionName.ReuseAddress, true); | |||
_socket.ReuseAddress = true; | |||
} | |||
if (_options.NoDelay) | |||
{ | |||
_socket.SetSocketOption(SocketOptionLevel.Tcp, SocketOptionName.NoDelay, true); | |||
_socket.NoDelay = true; | |||
} | |||
_socket.Bind(_localEndPoint); | |||
@@ -107,7 +107,7 @@ namespace MQTTnet.Implementations | |||
{ | |||
try | |||
{ | |||
var clientSocket = await PlatformAbstractionLayer.AcceptAsync(_socket).ConfigureAwait(false); | |||
var clientSocket = await _socket.AcceptAsync().ConfigureAwait(false); | |||
if (clientSocket == null) | |||
{ | |||
continue; | |||
@@ -135,7 +135,7 @@ namespace MQTTnet.Implementations | |||
} | |||
} | |||
async Task TryHandleClientConnectionAsync(Socket clientSocket) | |||
async Task TryHandleClientConnectionAsync(CrossPlatformSocket clientSocket) | |||
{ | |||
Stream stream = null; | |||
string remoteEndPoint = null; | |||
@@ -151,7 +151,7 @@ namespace MQTTnet.Implementations | |||
clientSocket.NoDelay = _options.NoDelay; | |||
stream = new NetworkStream(clientSocket, true); | |||
stream = clientSocket.GetStream(); | |||
X509Certificate2 clientCertificate = null; | |||
@@ -1,94 +1,9 @@ | |||
using System; | |||
using System.Net; | |||
using System.Net.Sockets; | |||
using System.Threading.Tasks; | |||
using System.Threading.Tasks; | |||
namespace MQTTnet.Implementations | |||
{ | |||
public static class PlatformAbstractionLayer | |||
{ | |||
// TODO: Consider creating primitives like "MqttNetSocket" which will wrap all required methods and do the platform stuff. | |||
public static async Task<Socket> AcceptAsync(Socket socket) | |||
{ | |||
#if NET452 || NET461 | |||
try | |||
{ | |||
return await Task.Factory.FromAsync(socket.BeginAccept, socket.EndAccept, null).ConfigureAwait(false); | |||
} | |||
catch (ObjectDisposedException) | |||
{ | |||
return null; | |||
} | |||
#else | |||
return await socket.AcceptAsync().ConfigureAwait(false); | |||
#endif | |||
} | |||
public static Task ConnectAsync(Socket socket, IPAddress ip, int port) | |||
{ | |||
#if NET452 || NET461 | |||
return Task.Factory.FromAsync(socket.BeginConnect, socket.EndConnect, ip, port, null); | |||
#else | |||
return socket.ConnectAsync(ip, port); | |||
#endif | |||
} | |||
public static Task ConnectAsync(Socket socket, string host, int port) | |||
{ | |||
#if NET452 || NET461 | |||
return Task.Factory.FromAsync(socket.BeginConnect, socket.EndConnect, host, port, null); | |||
#else | |||
return socket.ConnectAsync(host, port); | |||
#endif | |||
} | |||
#if NET452 || NET461 | |||
public class SocketWrapper | |||
{ | |||
private readonly Socket _socket; | |||
private readonly ArraySegment<byte> _buffer; | |||
private readonly SocketFlags _socketFlags; | |||
public SocketWrapper(Socket socket, ArraySegment<byte> buffer, SocketFlags socketFlags) | |||
{ | |||
_socket = socket; | |||
_buffer = buffer; | |||
_socketFlags = socketFlags; | |||
} | |||
public static IAsyncResult BeginSend(AsyncCallback callback, object state) | |||
{ | |||
var real = (SocketWrapper)state; | |||
return real._socket.BeginSend(real._buffer.Array, real._buffer.Offset, real._buffer.Count, real._socketFlags, callback, state); | |||
} | |||
public static IAsyncResult BeginReceive(AsyncCallback callback, object state) | |||
{ | |||
var real = (SocketWrapper)state; | |||
return real._socket.BeginReceive(real._buffer.Array, real._buffer.Offset, real._buffer.Count, real._socketFlags, callback, state); | |||
} | |||
} | |||
#endif | |||
public static Task SendAsync(Socket socket, ArraySegment<byte> buffer, SocketFlags socketFlags) | |||
{ | |||
#if NET452 || NET461 | |||
return Task.Factory.FromAsync(SocketWrapper.BeginSend, socket.EndSend, new SocketWrapper(socket, buffer, socketFlags)); | |||
#else | |||
return socket.SendAsync(buffer, socketFlags); | |||
#endif | |||
} | |||
public static Task<int> ReceiveAsync(Socket socket, ArraySegment<byte> buffer, SocketFlags socketFlags) | |||
{ | |||
#if NET452 || NET461 | |||
return Task.Factory.FromAsync(SocketWrapper.BeginReceive, socket.EndReceive, new SocketWrapper(socket, buffer, socketFlags)); | |||
#else | |||
return socket.ReceiveAsync(buffer, socketFlags); | |||
#endif | |||
} | |||
public static Task CompletedTask | |||
{ | |||
get | |||
@@ -15,7 +15,7 @@ using System.Threading.Tasks; | |||
namespace MQTTnet.Server | |||
{ | |||
public class MqttClientConnection : IDisposable | |||
public sealed class MqttClientConnection : IDisposable | |||
{ | |||
private readonly MqttPacketIdentifierProvider _packetIdentifierProvider = new MqttPacketIdentifierProvider(); | |||
private readonly MqttPacketDispatcher _packetDispatcher = new MqttPacketDispatcher(); | |||
@@ -124,7 +124,7 @@ namespace MQTTnet.Server | |||
return _packageReceiverTask; | |||
} | |||
private async Task<MqttClientDisconnectType> RunInternalAsync(MqttConnectionValidatorContext connectionValidatorContext) | |||
async Task<MqttClientDisconnectType> RunInternalAsync(MqttConnectionValidatorContext connectionValidatorContext) | |||
{ | |||
var disconnectType = MqttClientDisconnectType.NotClean; | |||
try | |||
@@ -251,12 +251,12 @@ namespace MQTTnet.Server | |||
return disconnectType; | |||
} | |||
private void StopInternal() | |||
void StopInternal() | |||
{ | |||
_cancellationToken.Cancel(false); | |||
} | |||
private async Task EnqueueSubscribedRetainedMessagesAsync(ICollection<TopicFilter> topicFilters) | |||
async Task EnqueueSubscribedRetainedMessagesAsync(ICollection<TopicFilter> topicFilters) | |||
{ | |||
var retainedMessages = await _retainedMessagesManager.GetSubscribedMessagesAsync(topicFilters).ConfigureAwait(false); | |||
foreach (var applicationMessage in retainedMessages) | |||
@@ -265,7 +265,7 @@ namespace MQTTnet.Server | |||
} | |||
} | |||
private async Task HandleIncomingSubscribePacketAsync(MqttSubscribePacket subscribePacket) | |||
async Task HandleIncomingSubscribePacketAsync(MqttSubscribePacket subscribePacket) | |||
{ | |||
// TODO: Let the channel adapter create the packet. | |||
var subscribeResult = await Session.SubscriptionsManager.SubscribeAsync(subscribePacket, ConnectPacket).ConfigureAwait(false); | |||
@@ -281,14 +281,14 @@ namespace MQTTnet.Server | |||
await EnqueueSubscribedRetainedMessagesAsync(subscribePacket.TopicFilters).ConfigureAwait(false); | |||
} | |||
private async Task HandleIncomingUnsubscribePacketAsync(MqttUnsubscribePacket unsubscribePacket) | |||
async Task HandleIncomingUnsubscribePacketAsync(MqttUnsubscribePacket unsubscribePacket) | |||
{ | |||
// TODO: Let the channel adapter create the packet. | |||
var unsubscribeResult = await Session.SubscriptionsManager.UnsubscribeAsync(unsubscribePacket).ConfigureAwait(false); | |||
await SendAsync(unsubscribeResult).ConfigureAwait(false); | |||
} | |||
private Task HandleIncomingPublishPacketAsync(MqttPublishPacket publishPacket) | |||
Task HandleIncomingPublishPacketAsync(MqttPublishPacket publishPacket) | |||
{ | |||
Interlocked.Increment(ref _sentApplicationMessagesCount); | |||
@@ -313,7 +313,7 @@ namespace MQTTnet.Server | |||
} | |||
} | |||
private Task HandleIncomingPublishPacketWithQoS0Async(MqttPublishPacket publishPacket) | |||
Task HandleIncomingPublishPacketWithQoS0Async(MqttPublishPacket publishPacket) | |||
{ | |||
var applicationMessage = _dataConverter.CreateApplicationMessage(publishPacket); | |||
@@ -322,7 +322,7 @@ namespace MQTTnet.Server | |||
return Task.FromResult(0); | |||
} | |||
private Task HandleIncomingPublishPacketWithQoS1Async(MqttPublishPacket publishPacket) | |||
Task HandleIncomingPublishPacketWithQoS1Async(MqttPublishPacket publishPacket) | |||
{ | |||
var applicationMessage = _dataConverter.CreateApplicationMessage(publishPacket); | |||
_sessionsManager.DispatchApplicationMessage(applicationMessage, this); | |||
@@ -331,7 +331,7 @@ namespace MQTTnet.Server | |||
return SendAsync(pubAckPacket); | |||
} | |||
private Task HandleIncomingPublishPacketWithQoS2Async(MqttPublishPacket publishPacket) | |||
Task HandleIncomingPublishPacketWithQoS2Async(MqttPublishPacket publishPacket) | |||
{ | |||
var applicationMessage = _dataConverter.CreateApplicationMessage(publishPacket); | |||
_sessionsManager.DispatchApplicationMessage(applicationMessage, this); | |||
@@ -345,7 +345,7 @@ namespace MQTTnet.Server | |||
return SendAsync(pubRecPacket); | |||
} | |||
private async Task SendPendingPacketsAsync(CancellationToken cancellationToken) | |||
async Task SendPendingPacketsAsync(CancellationToken cancellationToken) | |||
{ | |||
MqttQueuedApplicationMessage queuedApplicationMessage = null; | |||
MqttPublishPacket publishPacket = null; | |||
@@ -459,7 +459,7 @@ namespace MQTTnet.Server | |||
} | |||
} | |||
private async Task SendAsync(MqttBasePacket packet) | |||
async Task SendAsync(MqttBasePacket packet) | |||
{ | |||
await _channelAdapter.SendPacketAsync(packet, _serverOptions.DefaultCommunicationTimeout, _cancellationToken.Token).ConfigureAwait(false); | |||
@@ -471,12 +471,12 @@ namespace MQTTnet.Server | |||
} | |||
} | |||
private void OnAdapterReadingPacketCompleted() | |||
void OnAdapterReadingPacketCompleted() | |||
{ | |||
_keepAliveMonitor?.Resume(); | |||
} | |||
private void OnAdapterReadingPacketStarted() | |||
void OnAdapterReadingPacketStarted() | |||
{ | |||
_keepAliveMonitor?.Pause(); | |||
} | |||
@@ -207,7 +207,7 @@ namespace MQTTnet.Server | |||
applicationMessage = interceptorContext.ApplicationMessage; | |||
} | |||
await _eventDispatcher.HandleApplicationMessageReceivedAsync(sender?.ClientId, applicationMessage).ConfigureAwait(false); | |||
await _eventDispatcher.SafeNotifyApplicationMessageReceivedAsync(sender?.ClientId, applicationMessage).ConfigureAwait(false); | |||
if (applicationMessage.Retain) | |||
{ | |||
@@ -237,7 +237,7 @@ namespace MQTTnet.Server | |||
string clientId = null; | |||
var clientWasConnected = true; | |||
MqttConnectPacket connectPacket = null; | |||
MqttConnectPacket connectPacket; | |||
try | |||
{ | |||
@@ -259,8 +259,6 @@ namespace MQTTnet.Server | |||
var connectionValidatorContext = await ValidateConnectionAsync(connectPacket, channelAdapter).ConfigureAwait(false); | |||
clientId = connectPacket.ClientId; | |||
if (connectionValidatorContext.ReasonCode != MqttConnectReasonCode.Success) | |||
{ | |||
clientWasConnected = false; | |||
@@ -272,9 +270,10 @@ namespace MQTTnet.Server | |||
return; | |||
} | |||
clientId = connectPacket.ClientId; | |||
var connection = await CreateConnectionAsync(connectPacket, connectionValidatorContext, channelAdapter).ConfigureAwait(false); | |||
await _eventDispatcher.HandleClientConnectedAsync(clientId).ConfigureAwait(false); | |||
await _eventDispatcher.SafeNotifyClientConnectedAsync(clientId).ConfigureAwait(false); | |||
disconnectType = await connection.RunAsync(connectionValidatorContext).ConfigureAwait(false); | |||
} | |||
@@ -303,7 +302,7 @@ namespace MQTTnet.Server | |||
if (clientId != null) | |||
{ | |||
await _eventDispatcher.TryHandleClientDisconnectedAsync(clientId, disconnectType).ConfigureAwait(false); | |||
await _eventDispatcher.SafeNotifyClientDisconnectedAsync(clientId, disconnectType).ConfigureAwait(false); | |||
} | |||
} | |||
} | |||
@@ -1,9 +1,9 @@ | |||
using System; | |||
using MQTTnet.Packets; | |||
using MQTTnet.Protocol; | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Linq; | |||
using System.Threading.Tasks; | |||
using MQTTnet.Packets; | |||
using MQTTnet.Protocol; | |||
namespace MQTTnet.Server | |||
{ | |||
@@ -67,7 +67,7 @@ namespace MQTTnet.Server | |||
_subscriptions[finalTopicFilter.Topic] = finalTopicFilter; | |||
} | |||
await _eventDispatcher.HandleClientSubscribedTopicAsync(_clientSession.ClientId, finalTopicFilter).ConfigureAwait(false); | |||
await _eventDispatcher.SafeNotifyClientSubscribedTopicAsync(_clientSession.ClientId, finalTopicFilter).ConfigureAwait(false); | |||
} | |||
} | |||
@@ -83,7 +83,7 @@ namespace MQTTnet.Server | |||
var interceptorContext = await InterceptSubscribeAsync(topicFilter).ConfigureAwait(false); | |||
if (!interceptorContext.AcceptSubscription) | |||
{ | |||
continue; | |||
continue; | |||
} | |||
if (interceptorContext.AcceptSubscription) | |||
@@ -93,7 +93,7 @@ namespace MQTTnet.Server | |||
_subscriptions[topicFilter.Topic] = topicFilter; | |||
} | |||
await _eventDispatcher.HandleClientSubscribedTopicAsync(_clientSession.ClientId, topicFilter).ConfigureAwait(false); | |||
await _eventDispatcher.SafeNotifyClientSubscribedTopicAsync(_clientSession.ClientId, topicFilter).ConfigureAwait(false); | |||
} | |||
} | |||
} | |||
@@ -131,9 +131,9 @@ namespace MQTTnet.Server | |||
foreach (var topicFilter in unsubscribePacket.TopicFilters) | |||
{ | |||
await _eventDispatcher.HandleClientUnsubscribedTopicAsync(_clientSession.ClientId, topicFilter).ConfigureAwait(false); | |||
await _eventDispatcher.SafeNotifyClientUnsubscribedTopicAsync(_clientSession.ClientId, topicFilter).ConfigureAwait(false); | |||
} | |||
return unsubAckPacket; | |||
} | |||
@@ -152,7 +152,7 @@ namespace MQTTnet.Server | |||
lock (_subscriptions) | |||
{ | |||
_subscriptions.Remove(topicFilter); | |||
} | |||
} | |||
} | |||
} | |||
@@ -7,7 +7,7 @@ namespace MQTTnet.Server | |||
{ | |||
public class MqttServerEventDispatcher | |||
{ | |||
private readonly IMqttNetLogger _logger; | |||
readonly IMqttNetLogger _logger; | |||
public MqttServerEventDispatcher(IMqttNetLogger logger) | |||
{ | |||
@@ -24,18 +24,25 @@ namespace MQTTnet.Server | |||
public IMqttApplicationMessageReceivedHandler ApplicationMessageReceivedHandler { get; set; } | |||
public Task HandleClientConnectedAsync(string clientId) | |||
public async Task SafeNotifyClientConnectedAsync(string clientId) | |||
{ | |||
var handler = ClientConnectedHandler; | |||
if (handler == null) | |||
try | |||
{ | |||
return Task.FromResult(0); | |||
} | |||
var handler = ClientConnectedHandler; | |||
if (handler == null) | |||
{ | |||
return; | |||
} | |||
return handler.HandleClientConnectedAsync(new MqttServerClientConnectedEventArgs(clientId)); | |||
await handler.HandleClientConnectedAsync(new MqttServerClientConnectedEventArgs(clientId)).ConfigureAwait(false); | |||
} | |||
catch (Exception exception) | |||
{ | |||
_logger.Error(exception, "Error while handling custom 'ClientConnected' event."); | |||
} | |||
} | |||
public async Task TryHandleClientDisconnectedAsync(string clientId, MqttClientDisconnectType disconnectType) | |||
public async Task SafeNotifyClientDisconnectedAsync(string clientId, MqttClientDisconnectType disconnectType) | |||
{ | |||
try | |||
{ | |||
@@ -49,41 +56,62 @@ namespace MQTTnet.Server | |||
} | |||
catch (Exception exception) | |||
{ | |||
_logger.Error(exception, "Error while handling 'ClientDisconnected' event."); | |||
_logger.Error(exception, "Error while handling custom 'ClientDisconnected' event."); | |||
} | |||
} | |||
public Task HandleClientSubscribedTopicAsync(string clientId, TopicFilter topicFilter) | |||
public async Task SafeNotifyClientSubscribedTopicAsync(string clientId, TopicFilter topicFilter) | |||
{ | |||
var handler = ClientSubscribedTopicHandler; | |||
if (handler == null) | |||
try | |||
{ | |||
return Task.FromResult(0); | |||
} | |||
var handler = ClientSubscribedTopicHandler; | |||
if (handler == null) | |||
{ | |||
return; | |||
} | |||
return handler.HandleClientSubscribedTopicAsync(new MqttServerClientSubscribedTopicEventArgs(clientId, topicFilter)); | |||
await handler.HandleClientSubscribedTopicAsync(new MqttServerClientSubscribedTopicEventArgs(clientId, topicFilter)).ConfigureAwait(false); | |||
} | |||
catch (Exception exception) | |||
{ | |||
_logger.Error(exception, "Error while handling custom 'ClientSubscribedTopic' event."); | |||
} | |||
} | |||
public Task HandleClientUnsubscribedTopicAsync(string clientId, string topicFilter) | |||
public async Task SafeNotifyClientUnsubscribedTopicAsync(string clientId, string topicFilter) | |||
{ | |||
var handler = ClientUnsubscribedTopicHandler; | |||
if (handler == null) | |||
try | |||
{ | |||
return Task.FromResult(0); | |||
} | |||
var handler = ClientUnsubscribedTopicHandler; | |||
if (handler == null) | |||
{ | |||
return; | |||
} | |||
return handler.HandleClientUnsubscribedTopicAsync(new MqttServerClientUnsubscribedTopicEventArgs(clientId, topicFilter)); | |||
await handler.HandleClientUnsubscribedTopicAsync(new MqttServerClientUnsubscribedTopicEventArgs(clientId, topicFilter)).ConfigureAwait(false); | |||
} | |||
catch (Exception exception) | |||
{ | |||
_logger.Error(exception, "Error while handling custom 'ClientUnsubscribedTopic' event."); | |||
} | |||
} | |||
public Task HandleApplicationMessageReceivedAsync(string senderClientId, MqttApplicationMessage applicationMessage) | |||
public async Task SafeNotifyApplicationMessageReceivedAsync(string senderClientId, MqttApplicationMessage applicationMessage) | |||
{ | |||
var handler = ApplicationMessageReceivedHandler; | |||
if (handler == null) | |||
try | |||
{ | |||
return Task.FromResult(0); | |||
} | |||
var handler = ApplicationMessageReceivedHandler; | |||
if (handler == null) | |||
{ | |||
return; | |||
} | |||
return handler.HandleApplicationMessageReceivedAsync(new MqttApplicationMessageReceivedEventArgs(senderClientId, applicationMessage)); | |||
await handler.HandleApplicationMessageReceivedAsync(new MqttApplicationMessageReceivedEventArgs(senderClientId, applicationMessage)).ConfigureAwait(false); ; | |||
} | |||
catch (Exception exception) | |||
{ | |||
_logger.Error(exception, "Error while handling custom 'ApplicationMessageReceived' event."); | |||
} | |||
} | |||
} | |||
} |
@@ -0,0 +1,74 @@ | |||
using Microsoft.VisualStudio.TestTools.UnitTesting; | |||
using MQTTnet.Implementations; | |||
using System; | |||
using System.Text; | |||
using System.Threading; | |||
using System.Threading.Tasks; | |||
namespace MQTTnet.Tests | |||
{ | |||
[TestClass] | |||
public class CrossPlatformSocket_Tests | |||
{ | |||
[TestMethod] | |||
public async Task Connect_Send_Receive() | |||
{ | |||
var crossPlatformSocket = new CrossPlatformSocket(); | |||
await crossPlatformSocket.ConnectAsync("www.google.de", 80, CancellationToken.None); | |||
var requestBuffer = Encoding.UTF8.GetBytes("GET / HTTP/1.1\r\nHost: www.google.de\r\n\r\n"); | |||
await crossPlatformSocket.SendAsync(new ArraySegment<byte>(requestBuffer), System.Net.Sockets.SocketFlags.None); | |||
var buffer = new byte[1024]; | |||
var length = await crossPlatformSocket.ReceiveAsync(new ArraySegment<byte>(buffer), System.Net.Sockets.SocketFlags.None); | |||
crossPlatformSocket.Dispose(); | |||
var responseText = Encoding.UTF8.GetString(buffer, 0, length); | |||
Assert.IsTrue(responseText.Contains("HTTP/1.1 200 OK")); | |||
} | |||
[TestMethod] | |||
public async Task Try_Connect_Invalid_Host() | |||
{ | |||
var crossPlatformSocket = new CrossPlatformSocket(); | |||
var cancellationToken = new CancellationTokenSource(TimeSpan.FromSeconds(3)); | |||
cancellationToken.Token.Register(() => crossPlatformSocket.Dispose()); | |||
await crossPlatformSocket.ConnectAsync("www.google.de", 1234, CancellationToken.None); | |||
} | |||
//[TestMethod] | |||
//public async Task Use_Disconnected_Socket() | |||
//{ | |||
// var crossPlatformSocket = new CrossPlatformSocket(); | |||
// await crossPlatformSocket.ConnectAsync("www.google.de", 80); | |||
// var requestBuffer = Encoding.UTF8.GetBytes("GET /wrong_uri HTTP/1.1\r\nConnection: close\r\n\r\n"); | |||
// await crossPlatformSocket.SendAsync(new ArraySegment<byte>(requestBuffer), System.Net.Sockets.SocketFlags.None); | |||
// var buffer = new byte[64000]; | |||
// var length = await crossPlatformSocket.ReceiveAsync(new ArraySegment<byte>(buffer), System.Net.Sockets.SocketFlags.None); | |||
// await Task.Delay(500); | |||
// await crossPlatformSocket.SendAsync(new ArraySegment<byte>(requestBuffer), System.Net.Sockets.SocketFlags.None); | |||
//} | |||
[TestMethod] | |||
public async Task Set_Options() | |||
{ | |||
var crossPlatformSocket = new CrossPlatformSocket(); | |||
Assert.IsFalse(crossPlatformSocket.ReuseAddress); | |||
crossPlatformSocket.ReuseAddress = true; | |||
Assert.IsTrue(crossPlatformSocket.ReuseAddress); | |||
Assert.IsFalse(crossPlatformSocket.NoDelay); | |||
crossPlatformSocket.NoDelay = true; | |||
Assert.IsTrue(crossPlatformSocket.NoDelay); | |||
} | |||
} | |||
} |
@@ -1,10 +1,10 @@ | |||
using System; | |||
using Microsoft.VisualStudio.TestTools.UnitTesting; | |||
using MQTTnet.Implementations; | |||
using System; | |||
using System.Net; | |||
using System.Net.Sockets; | |||
using System.Threading; | |||
using System.Threading.Tasks; | |||
using Microsoft.VisualStudio.TestTools.UnitTesting; | |||
using MQTTnet.Implementations; | |||
namespace MQTTnet.Tests | |||
{ | |||
@@ -15,7 +15,7 @@ namespace MQTTnet.Tests | |||
public async Task Dispose_Channel_While_Used() | |||
{ | |||
var ct = new CancellationTokenSource(); | |||
var serverSocket = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); | |||
var serverSocket = new CrossPlatformSocket(AddressFamily.InterNetwork); | |||
try | |||
{ | |||
@@ -28,18 +28,18 @@ namespace MQTTnet.Tests | |||
{ | |||
while (!ct.IsCancellationRequested) | |||
{ | |||
var client = await PlatformAbstractionLayer.AcceptAsync(serverSocket); | |||
var client = await serverSocket.AcceptAsync(); | |||
var data = new byte[] { 128 }; | |||
await PlatformAbstractionLayer.SendAsync(client, new ArraySegment<byte>(data), SocketFlags.None); | |||
await client.SendAsync(new ArraySegment<byte>(data), SocketFlags.None); | |||
} | |||
}, ct.Token); | |||
var clientSocket = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); | |||
await PlatformAbstractionLayer.ConnectAsync(clientSocket, IPAddress.Loopback, 50001); | |||
var clientSocket = new CrossPlatformSocket(AddressFamily.InterNetwork); | |||
await clientSocket.ConnectAsync("localhost", 50001, CancellationToken.None); | |||
await Task.Delay(100, ct.Token); | |||
var tcpChannel = new MqttTcpChannel(new NetworkStream(clientSocket, true), "test", null); | |||
var tcpChannel = new MqttTcpChannel(clientSocket.GetStream(), "test", null); | |||
var buffer = new byte[1]; | |||
await tcpChannel.ReadAsync(buffer, 0, 1, ct.Token); | |||
@@ -904,14 +904,13 @@ namespace MQTTnet.Tests | |||
await testEnvironment.StartServerAsync(serverOptions); | |||
var connectingFailedException = await Assert.ThrowsExceptionAsync<MqttConnectingFailedException>(() => testEnvironment.ConnectClientAsync()); | |||
Assert.AreEqual(MqttClientConnectResultCode.NotAuthorized, connectingFailedException.ResultCode); | |||
} | |||
} | |||
Dictionary<string, bool> _connected; | |||
private Dictionary<string, bool> _connected; | |||
private void ConnectionValidationHandler(MqttConnectionValidatorContext eventArgs) | |||
{ | |||
if (_connected.ContainsKey(eventArgs.ClientId)) | |||
@@ -919,6 +918,7 @@ namespace MQTTnet.Tests | |||
eventArgs.ReasonCode = MqttConnectReasonCode.BadUserNameOrPassword; | |||
return; | |||
} | |||
_connected[eventArgs.ClientId] = true; | |||
eventArgs.ReasonCode = MqttConnectReasonCode.Success; | |||
return; | |||
@@ -1053,6 +1053,12 @@ namespace MQTTnet.Tests | |||
// Connect client with same client ID. Should disconnect existing client. | |||
var c2 = await testEnvironment.ConnectClientAsync(clientOptionsBuilder); | |||
await Task.Delay(500); | |||
flow = string.Join(string.Empty, events); | |||
Assert.AreEqual("cdc", flow); | |||
c2.UseApplicationMessageReceivedHandler(_ => | |||
{ | |||
lock (events) | |||
@@ -1061,15 +1067,10 @@ namespace MQTTnet.Tests | |||
} | |||
}); | |||
c2.SubscribeAsync("topic").Wait(); | |||
await Task.Delay(500); | |||
flow = string.Join(string.Empty, events); | |||
Assert.AreEqual("cdc", flow); | |||
await c2.SubscribeAsync("topic"); | |||
// r | |||
c2.PublishAsync("topic").Wait(); | |||
await c2.PublishAsync("topic"); | |||
await Task.Delay(500); | |||
@@ -1149,15 +1150,15 @@ namespace MQTTnet.Tests | |||
{ | |||
await testEnvironment.StartServerAsync(new MqttServerOptionsBuilder().WithDefaultCommunicationTimeout(TimeSpan.FromSeconds(1))); | |||
var client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); | |||
await PlatformAbstractionLayer.ConnectAsync(client, "localhost", testEnvironment.ServerPort); | |||
var client = new CrossPlatformSocket(AddressFamily.InterNetwork); | |||
await client.ConnectAsync("localhost", testEnvironment.ServerPort, CancellationToken.None); | |||
// Don't send anything. The server should close the connection. | |||
await Task.Delay(TimeSpan.FromSeconds(3)); | |||
try | |||
{ | |||
var receivedBytes = await PlatformAbstractionLayer.ReceiveAsync(client, new ArraySegment<byte>(new byte[10]), SocketFlags.Partial); | |||
var receivedBytes = await client.ReceiveAsync(new ArraySegment<byte>(new byte[10]), SocketFlags.Partial); | |||
if (receivedBytes == 0) | |||
{ | |||
return; | |||
@@ -1180,17 +1181,17 @@ namespace MQTTnet.Tests | |||
// Send an invalid packet and ensure that the server will close the connection and stay in a waiting state | |||
// forever. This is security related. | |||
var client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); | |||
await PlatformAbstractionLayer.ConnectAsync(client, "localhost", testEnvironment.ServerPort); | |||
var client = new CrossPlatformSocket(AddressFamily.InterNetwork); | |||
await client.ConnectAsync("localhost", testEnvironment.ServerPort, CancellationToken.None); | |||
var buffer = Encoding.UTF8.GetBytes("Garbage"); | |||
client.Send(buffer, buffer.Length, SocketFlags.None); | |||
await client.SendAsync(new ArraySegment<byte>(buffer), SocketFlags.None); | |||
await Task.Delay(TimeSpan.FromSeconds(3)); | |||
try | |||
{ | |||
var receivedBytes = await PlatformAbstractionLayer.ReceiveAsync(client, new ArraySegment<byte>(new byte[10]), SocketFlags.Partial); | |||
var receivedBytes = await client.ReceiveAsync(new ArraySegment<byte>(new byte[10]), SocketFlags.Partial); | |||
if (receivedBytes == 0) | |||
{ | |||
return; | |||