diff --git a/Build/MQTTnet.nuspec b/Build/MQTTnet.nuspec index 963ee06..81e6caf 100644 --- a/Build/MQTTnet.nuspec +++ b/Build/MQTTnet.nuspec @@ -11,8 +11,15 @@ false MQTTnet is a high performance .NET library for MQTT based communication. It provides a MQTT client and a MQTT server (broker) and supports v3.1.0, v3.1.1 and v5.0.0 of the MQTT protocol. +* [ManagedClient] Added builder class for MqttClientUnsubscribeOptions (thanks to @dominikviererbe). * [ManagedClient] Added support for persisted sessions (thansk to @PMExtra). +* [Client] Improve connection stability (thanks to @jltjohanlindqvist). +* [ManagedClient] Fixed a memory leak (thanks to @zawodskoj). +* [ManagedClient] Improved internal subscription management (#569, thanks to @cstichlberger). +* [ManagedClient] Refactored log messages (thanks to @cstichlberger). * [Server] Added support for assigned client IDs (MQTTv5 only) (thanks to @bcrosnier). +* [Server] Added interceptor for unsubscriptions. +* [MQTTnet.Server] Added interceptor for unsubscriptions. Copyright Christian Kratky 2016-2019 MQTT Message Queue Telemetry Transport MQTTClient MQTTServer Server MQTTBroker Broker NETStandard IoT InternetOfThings Messaging Hardware Arduino Sensor Actuator M2M ESP Smart Home Cities Automation Xamarin diff --git a/Source/MQTTnet.Extensions.ManagedClient/ManagedMqttClient.cs b/Source/MQTTnet.Extensions.ManagedClient/ManagedMqttClient.cs index c9053e4..8b5a48f 100644 --- a/Source/MQTTnet.Extensions.ManagedClient/ManagedMqttClient.cs +++ b/Source/MQTTnet.Extensions.ManagedClient/ManagedMqttClient.cs @@ -16,15 +16,25 @@ using MQTTnet.Server; namespace MQTTnet.Extensions.ManagedClient { - public class ManagedMqttClient : IManagedMqttClient + public class ManagedMqttClient : Disposable, IManagedMqttClient { private readonly BlockingQueue _messageQueue = new BlockingQueue(); + + /// + /// The subscriptions are managed in 2 separate buckets: + /// and are processed during normal operation + /// and are moved to the when they get processed. They can be accessed by + /// any thread and are therefore mutex'ed. get sent to the broker + /// at reconnect and are solely owned by . + /// + private readonly Dictionary _reconnectSubscriptions = new Dictionary(); private readonly Dictionary _subscriptions = new Dictionary(); private readonly HashSet _unsubscriptions = new HashSet(); + private readonly SemaphoreSlim _subscriptionsQueuedSignal = new SemaphoreSlim(0); private readonly IMqttClient _mqttClient; private readonly IMqttNetChildLogger _logger; - + private readonly AsyncLock _messageQueueLock = new AsyncLock(); private CancellationTokenSource _connectionCancellationToken; @@ -32,10 +42,7 @@ namespace MQTTnet.Extensions.ManagedClient private Task _maintainConnectionTask; private ManagedMqttClientStorageManager _storageManager; - - private bool _disposed; - private bool _subscriptionsNotPushed; - + public ManagedMqttClient(IMqttClient mqttClient, IMqttNetChildLogger logger) { _mqttClient = mqttClient ?? throw new ArgumentNullException(nameof(mqttClient)); @@ -82,10 +89,6 @@ namespace MQTTnet.Extensions.ManagedClient if (options == null) throw new ArgumentNullException(nameof(options)); if (options.ClientOptions == null) throw new ArgumentException("The client options are not set.", nameof(options)); - if (!options.ClientOptions.CleanSession) - { - throw new NotSupportedException("The managed client does not support existing sessions."); - } if (!_maintainConnectionTask?.IsCompleted ?? false) throw new InvalidOperationException("The managed client is already started."); @@ -141,6 +144,7 @@ namespace MQTTnet.Extensions.ManagedClient ThrowIfDisposed(); if (applicationMessage == null) throw new ArgumentNullException(nameof(applicationMessage)); + if (Options == null) throw new InvalidOperationException("call StartAsync before publishing messages"); MqttTopicValidator.ThrowIfInvalid(applicationMessage.ApplicationMessage.Topic); @@ -169,7 +173,7 @@ namespace MQTTnet.Extensions.ManagedClient } _messageQueue.Enqueue(applicationMessage); - + if (_storageManager != null) { if (removedMessage != null) @@ -206,9 +210,10 @@ namespace MQTTnet.Extensions.ManagedClient foreach (var topicFilter in topicFilters) { _subscriptions[topicFilter.Topic] = topicFilter.QualityOfServiceLevel; - _subscriptionsNotPushed = true; + _unsubscriptions.Remove(topicFilter.Topic); } } + _subscriptionsQueuedSignal.Release(); return Task.FromResult(0); } @@ -223,45 +228,34 @@ namespace MQTTnet.Extensions.ManagedClient { foreach (var topic in topics) { - if (_subscriptions.Remove(topic)) - { - _unsubscriptions.Add(topic); - _subscriptionsNotPushed = true; - } + _subscriptions.Remove(topic); + _unsubscriptions.Add(topic); } } + _subscriptionsQueuedSignal.Release(); return Task.FromResult(0); } - public void Dispose() + protected override void Dispose(bool disposing) { - if (_disposed) - { - return; - } - - _disposed = true; - - StopPublishing(); - StopMaintainingConnection(); - - if (_maintainConnectionTask != null) + if (disposing) { - Task.WaitAny(_maintainConnectionTask); - _maintainConnectionTask = null; - } + StopPublishing(); + StopMaintainingConnection(); - _messageQueueLock.Dispose(); - _mqttClient.Dispose(); - } + if (_maintainConnectionTask != null) + { + _maintainConnectionTask.GetAwaiter().GetResult(); + _maintainConnectionTask = null; + } - private void ThrowIfDisposed() - { - if (_disposed) - { - throw new ObjectDisposedException(nameof(ManagedMqttClient)); + _messageQueue.Dispose(); + _messageQueueLock.Dispose(); + _mqttClient.Dispose(); + _subscriptionsQueuedSignal.Dispose(); } + base.Dispose(disposing); } private async Task MaintainConnectionAsync(CancellationToken cancellationToken) @@ -278,11 +272,11 @@ namespace MQTTnet.Extensions.ManagedClient } catch (Exception exception) { - _logger.Error(exception, "Unhandled exception while maintaining connection."); + _logger.Error(exception, "Error exception while maintaining connection."); } finally { - if (!_disposed) + if (!IsDisposed) { try { @@ -295,6 +289,12 @@ namespace MQTTnet.Extensions.ManagedClient _logger.Info("Stopped"); } + _reconnectSubscriptions.Clear(); + lock (_subscriptions) + { + _subscriptions.Clear(); + _unsubscriptions.Clear(); + } } } @@ -310,16 +310,22 @@ namespace MQTTnet.Extensions.ManagedClient return; } - if (connectionState == ReconnectionResult.Reconnected || _subscriptionsNotPushed) + if (connectionState == ReconnectionResult.Reconnected) + { + await PublishReconnectSubscriptionsAsync().ConfigureAwait(false); + StartPublishing(); + return; + } + + if (connectionState == ReconnectionResult.Recovered) { - await SynchronizeSubscriptionsAsync().ConfigureAwait(false); StartPublishing(); return; } if (connectionState == ReconnectionResult.StillConnected) { - await Task.Delay(Options.ConnectionCheckInterval, cancellationToken).ConfigureAwait(false); + await PublishSubscriptionsAsync(Options.ConnectionCheckInterval, cancellationToken).ConfigureAwait(false); } } catch (OperationCanceledException) @@ -327,11 +333,11 @@ namespace MQTTnet.Extensions.ManagedClient } catch (MqttCommunicationException exception) { - _logger.Warning(exception, "Communication exception while maintaining connection."); + _logger.Warning(exception, "Communication error while maintaining connection."); } catch (Exception exception) { - _logger.Error(exception, "Unhandled exception while maintaining connection."); + _logger.Error(exception, "Error exception while maintaining connection."); } } @@ -349,7 +355,7 @@ namespace MQTTnet.Extensions.ManagedClient // of the messages, the DropOldestQueuedMessage strategy would // be unable to know which message is actually the oldest and would // instead drop the first item in the queue. - var message = _messageQueue.PeekAndWait(); + var message = _messageQueue.PeekAndWait(cancellationToken); if (message == null) { continue; @@ -389,7 +395,7 @@ namespace MQTTnet.Extensions.ManagedClient // it from the queue. If not, that means this.PublishAsync has already // removed it, in which case we don't want to do anything. _messageQueue.RemoveFirst(i => i.Id.Equals(message.Id)); - + if (_storageManager != null) { await _storageManager.RemoveAsync(message).ConfigureAwait(false); @@ -414,7 +420,7 @@ namespace MQTTnet.Extensions.ManagedClient using (await _messageQueueLock.WaitAsync(CancellationToken.None).ConfigureAwait(false)) //lock to avoid conflict with this.PublishAsync { _messageQueue.RemoveFirst(i => i.Id.Equals(message.Id)); - + if (_storageManager != null) { await _storageManager.RemoveAsync(message).ConfigureAwait(false); @@ -438,50 +444,84 @@ namespace MQTTnet.Extensions.ManagedClient } } - private async Task SynchronizeSubscriptionsAsync() + private async Task PublishSubscriptionsAsync(TimeSpan timeout, CancellationToken cancellationToken) { - _logger.Info("Synchronizing subscriptions"); + var endTime = DateTime.UtcNow + timeout; + while (await _subscriptionsQueuedSignal.WaitAsync(GetRemainingTime(endTime), cancellationToken).ConfigureAwait(false)) + { + List subscriptions; + HashSet unsubscriptions; - List subscriptions; - HashSet unsubscriptions; + lock (_subscriptions) + { + subscriptions = _subscriptions.Select(i => new TopicFilter { Topic = i.Key, QualityOfServiceLevel = i.Value }).ToList(); + _subscriptions.Clear(); + unsubscriptions = new HashSet(_unsubscriptions); + _unsubscriptions.Clear(); + } - lock (_subscriptions) - { - subscriptions = _subscriptions.Select(i => new TopicFilter { Topic = i.Key, QualityOfServiceLevel = i.Value }).ToList(); + if (!subscriptions.Any() && !unsubscriptions.Any()) + { + continue; + } - unsubscriptions = new HashSet(_unsubscriptions); - _unsubscriptions.Clear(); + _logger.Verbose($"Publishing subscriptions ({subscriptions.Count} subscriptions and {unsubscriptions.Count} unsubscriptions)"); - _subscriptionsNotPushed = false; - } + foreach (var unsubscription in unsubscriptions) + { + _reconnectSubscriptions.Remove(unsubscription); + } - if (!subscriptions.Any() && !unsubscriptions.Any()) - { - return; - } + foreach (var subscription in subscriptions) + { + _reconnectSubscriptions[subscription.Topic] = subscription.QualityOfServiceLevel; + } - try - { - if (unsubscriptions.Any()) + try + { + if (unsubscriptions.Any()) + { + await _mqttClient.UnsubscribeAsync(unsubscriptions.ToArray()).ConfigureAwait(false); + } + + if (subscriptions.Any()) + { + await _mqttClient.SubscribeAsync(subscriptions.ToArray()).ConfigureAwait(false); + } + } + catch (Exception exception) { - await _mqttClient.UnsubscribeAsync(unsubscriptions.ToArray()).ConfigureAwait(false); + await HandleSubscriptionExceptionAsync(exception).ConfigureAwait(false); } + } + } + + private async Task PublishReconnectSubscriptionsAsync() + { + _logger.Info("Publishing subscriptions at reconnect"); - if (subscriptions.Any()) + try + { + if (_reconnectSubscriptions.Any()) { + var subscriptions = _reconnectSubscriptions.Select(i => new TopicFilter { Topic = i.Key, QualityOfServiceLevel = i.Value }); await _mqttClient.SubscribeAsync(subscriptions.ToArray()).ConfigureAwait(false); } } catch (Exception exception) { - _logger.Warning(exception, "Synchronizing subscriptions failed."); - _subscriptionsNotPushed = true; + await HandleSubscriptionExceptionAsync(exception).ConfigureAwait(false); + } + } - var synchronizingSubscriptionsFailedHandler = SynchronizingSubscriptionsFailedHandler; - if (SynchronizingSubscriptionsFailedHandler != null) - { - await synchronizingSubscriptionsFailedHandler.HandleSynchronizingSubscriptionsFailedAsync(new ManagedProcessFailedEventArgs(exception)).ConfigureAwait(false); - } + private async Task HandleSubscriptionExceptionAsync(Exception exception) + { + _logger.Warning(exception, "Synchronizing subscriptions failed."); + + var synchronizingSubscriptionsFailedHandler = SynchronizingSubscriptionsFailedHandler; + if (SynchronizingSubscriptionsFailedHandler != null) + { + await synchronizingSubscriptionsFailedHandler.HandleSynchronizingSubscriptionsFailedAsync(new ManagedProcessFailedEventArgs(exception)).ConfigureAwait(false); } } @@ -494,8 +534,8 @@ namespace MQTTnet.Extensions.ManagedClient try { - await _mqttClient.ConnectAsync(Options.ClientOptions).ConfigureAwait(false); - return ReconnectionResult.Reconnected; + var result = await _mqttClient.ConnectAsync(Options.ClientOptions).ConfigureAwait(false); + return result.IsSessionPresent ? ReconnectionResult.Recovered : ReconnectionResult.Reconnected; } catch (Exception exception) { @@ -508,7 +548,7 @@ namespace MQTTnet.Extensions.ManagedClient return ReconnectionResult.NotConnected; } } - + private void StartPublishing() { if (_publishingCancellationToken != null) @@ -535,5 +575,11 @@ namespace MQTTnet.Extensions.ManagedClient _connectionCancellationToken?.Dispose(); _connectionCancellationToken = null; } + + private TimeSpan GetRemainingTime(DateTime endTime) + { + var remainingTime = endTime - DateTime.UtcNow; + return remainingTime < TimeSpan.Zero ? TimeSpan.Zero : remainingTime; + } } } diff --git a/Source/MQTTnet.Extensions.ManagedClient/ReconnectionResult.cs b/Source/MQTTnet.Extensions.ManagedClient/ReconnectionResult.cs index fa876c3..092662f 100644 --- a/Source/MQTTnet.Extensions.ManagedClient/ReconnectionResult.cs +++ b/Source/MQTTnet.Extensions.ManagedClient/ReconnectionResult.cs @@ -4,6 +4,7 @@ { StillConnected, Reconnected, + Recovered, NotConnected } } diff --git a/Source/MQTTnet.Extensions.WebSocket4Net/WebSocket4NetMqttChannel.cs b/Source/MQTTnet.Extensions.WebSocket4Net/WebSocket4NetMqttChannel.cs index 68b12ef..47d9682 100644 --- a/Source/MQTTnet.Extensions.WebSocket4Net/WebSocket4NetMqttChannel.cs +++ b/Source/MQTTnet.Extensions.WebSocket4Net/WebSocket4NetMqttChannel.cs @@ -85,7 +85,12 @@ namespace MQTTnet.Extensions.WebSocket4Net { foreach (var certificate in _webSocketOptions.TlsOptions.Certificates) { +#if WINDOWS_UWP certificates.Add(new X509Certificate(certificate)); +#else + certificates.Add(certificate); +#endif + } } diff --git a/Source/MQTTnet.Server/Mqtt/MqttServerService.cs b/Source/MQTTnet.Server/Mqtt/MqttServerService.cs index b8c463f..85c4176 100644 --- a/Source/MQTTnet.Server/Mqtt/MqttServerService.cs +++ b/Source/MQTTnet.Server/Mqtt/MqttServerService.cs @@ -33,6 +33,7 @@ namespace MQTTnet.Server.Mqtt private readonly MqttServerConnectionValidator _mqttConnectionValidator; private readonly IMqttServer _mqttServer; private readonly MqttSubscriptionInterceptor _mqttSubscriptionInterceptor; + private readonly MqttUnsubscriptionInterceptor _mqttUnsubscriptionInterceptor; private readonly PythonScriptHostService _pythonScriptHostService; private readonly MqttWebSocketServerAdapter _webSocketServerAdapter; @@ -45,6 +46,7 @@ namespace MQTTnet.Server.Mqtt MqttClientUnsubscribedTopicHandler mqttClientUnsubscribedTopicHandler, MqttServerConnectionValidator mqttConnectionValidator, MqttSubscriptionInterceptor mqttSubscriptionInterceptor, + MqttUnsubscriptionInterceptor mqttUnsubscriptionInterceptor, MqttApplicationMessageInterceptor mqttApplicationMessageInterceptor, MqttServerStorage mqttServerStorage, PythonScriptHostService pythonScriptHostService, @@ -57,6 +59,7 @@ namespace MQTTnet.Server.Mqtt _mqttClientUnsubscribedTopicHandler = mqttClientUnsubscribedTopicHandler ?? throw new ArgumentNullException(nameof(mqttClientUnsubscribedTopicHandler)); _mqttConnectionValidator = mqttConnectionValidator ?? throw new ArgumentNullException(nameof(mqttConnectionValidator)); _mqttSubscriptionInterceptor = mqttSubscriptionInterceptor ?? throw new ArgumentNullException(nameof(mqttSubscriptionInterceptor)); + _mqttUnsubscriptionInterceptor = mqttUnsubscriptionInterceptor ?? throw new ArgumentNullException(nameof(mqttUnsubscriptionInterceptor)); _mqttApplicationMessageInterceptor = mqttApplicationMessageInterceptor ?? throw new ArgumentNullException(nameof(mqttApplicationMessageInterceptor)); _mqttServerStorage = mqttServerStorage ?? throw new ArgumentNullException(nameof(mqttServerStorage)); _pythonScriptHostService = pythonScriptHostService ?? throw new ArgumentNullException(nameof(pythonScriptHostService)); @@ -178,6 +181,7 @@ namespace MQTTnet.Server.Mqtt .WithConnectionValidator(_mqttConnectionValidator) .WithApplicationMessageInterceptor(_mqttApplicationMessageInterceptor) .WithSubscriptionInterceptor(_mqttSubscriptionInterceptor) + .WithUnsubscriptionInterceptor(_mqttUnsubscriptionInterceptor) .WithStorage(_mqttServerStorage); // Configure unencrypted connections diff --git a/Source/MQTTnet.Server/Mqtt/MqttUnsubscriptionInterceptor.cs b/Source/MQTTnet.Server/Mqtt/MqttUnsubscriptionInterceptor.cs new file mode 100644 index 0000000..1a460af --- /dev/null +++ b/Source/MQTTnet.Server/Mqtt/MqttUnsubscriptionInterceptor.cs @@ -0,0 +1,48 @@ +using System; +using System.Threading.Tasks; +using IronPython.Runtime; +using Microsoft.Extensions.Logging; +using MQTTnet.Server.Scripting; + +namespace MQTTnet.Server.Mqtt +{ + public class MqttUnsubscriptionInterceptor : IMqttServerUnsubscriptionInterceptor + { + private readonly PythonScriptHostService _pythonScriptHostService; + private readonly ILogger _logger; + + public MqttUnsubscriptionInterceptor(PythonScriptHostService pythonScriptHostService, ILogger logger) + { + _pythonScriptHostService = pythonScriptHostService ?? throw new ArgumentNullException(nameof(pythonScriptHostService)); + _logger = logger ?? throw new ArgumentNullException(nameof(logger)); + } + + public Task InterceptUnsubscriptionAsync(MqttUnsubscriptionInterceptorContext context) + { + try + { + var sessionItems = (PythonDictionary)context.SessionItems[MqttServerConnectionValidator.WrappedSessionItemsKey]; + + var pythonContext = new PythonDictionary + { + { "client_id", context.ClientId }, + { "session_items", sessionItems }, + { "topic", context.Topic }, + { "accept_unsubscription", context.AcceptUnsubscription }, + { "close_connection", context.CloseConnection } + }; + + _pythonScriptHostService.InvokeOptionalFunction("on_intercept_unsubscription", pythonContext); + + context.AcceptUnsubscription = (bool)pythonContext["accept_unsubscription"]; + context.CloseConnection = (bool)pythonContext["close_connection"]; + } + catch (Exception exception) + { + _logger.LogError(exception, "Error while intercepting unsubscription."); + } + + return Task.CompletedTask; + } + } +} diff --git a/Source/MQTTnet/Adapter/MqttChannelAdapter.cs b/Source/MQTTnet/Adapter/MqttChannelAdapter.cs index 4a0a85d..f364168 100644 --- a/Source/MQTTnet/Adapter/MqttChannelAdapter.cs +++ b/Source/MQTTnet/Adapter/MqttChannelAdapter.cs @@ -14,7 +14,7 @@ using MQTTnet.Packets; namespace MQTTnet.Adapter { - public class MqttChannelAdapter : IMqttChannelAdapter + public class MqttChannelAdapter : Disposable, IMqttChannelAdapter { private const uint ErrorOperationAborted = 0x800703E3; private const int ReadBufferSize = 4096; // TODO: Move buffer size to config @@ -26,9 +26,7 @@ namespace MQTTnet.Adapter private readonly MqttPacketReader _packetReader; private readonly byte[] _fixedHeaderBuffer = new byte[2]; - - private bool _isDisposed; - + private long _bytesReceived; private long _bytesSent; @@ -269,19 +267,13 @@ namespace MQTTnet.Adapter } } - public void Dispose() - { - _isDisposed = true; - - _channel?.Dispose(); - } - - private void ThrowIfDisposed() + protected override void Dispose(bool disposing) { - if (_isDisposed) + if (disposing) { - throw new ObjectDisposedException(nameof(MqttChannelAdapter)); + _channel?.Dispose(); } + base.Dispose(disposing); } private static bool IsWrappedException(Exception exception) diff --git a/Source/MQTTnet/Adapter/MqttConnectingFailedException.cs b/Source/MQTTnet/Adapter/MqttConnectingFailedException.cs index 44d50ec..ab49d93 100644 --- a/Source/MQTTnet/Adapter/MqttConnectingFailedException.cs +++ b/Source/MQTTnet/Adapter/MqttConnectingFailedException.cs @@ -5,12 +5,13 @@ namespace MQTTnet.Adapter { public class MqttConnectingFailedException : MqttCommunicationException { - public MqttConnectingFailedException(MqttClientConnectResultCode resultCode) - : base($"Connecting with MQTT server failed ({resultCode.ToString()}).") + public MqttConnectingFailedException(MqttClientAuthenticateResult result) + : base($"Connecting with MQTT server failed ({result.ResultCode.ToString()}).") { - ResultCode = resultCode; + Result = result; } - public MqttClientConnectResultCode ResultCode { get; } + public MqttClientAuthenticateResult Result { get; } + public MqttClientConnectResultCode ResultCode => Result.ResultCode; } } diff --git a/Source/MQTTnet/Client/MqttClient.cs b/Source/MQTTnet/Client/MqttClient.cs index 27b56ff..29687fc 100644 --- a/Source/MQTTnet/Client/MqttClient.cs +++ b/Source/MQTTnet/Client/MqttClient.cs @@ -1,4 +1,4 @@ -using System; +using System; using System.Diagnostics; using System.Threading; using System.Threading.Tasks; @@ -20,11 +20,12 @@ using MQTTnet.Protocol; namespace MQTTnet.Client { - public class MqttClient : IMqttClient + public class MqttClient : Disposable, IMqttClient { private readonly MqttPacketIdentifierProvider _packetIdentifierProvider = new MqttPacketIdentifierProvider(); private readonly MqttPacketDispatcher _packetDispatcher = new MqttPacketDispatcher(); private readonly Stopwatch _sendTracker = new Stopwatch(); + private readonly Stopwatch _receiveTracker = new Stopwatch(); private readonly object _disconnectLock = new object(); private readonly IMqttClientAdapterFactory _adapterFactory; @@ -63,6 +64,8 @@ namespace MQTTnet.Client ThrowIfConnected("It is not allowed to connect with a server after the connection is established."); + ThrowIfDisposed(); + MqttClientAuthenticateResult authenticateResult = null; try @@ -79,15 +82,19 @@ namespace MQTTnet.Client var adapter = _adapterFactory.CreateClientAdapter(options, _logger); _adapter = adapter; - _logger.Verbose($"Trying to connect with server '{options.ChannelOptions}' (Timeout={options.CommunicationTimeout})."); - await _adapter.ConnectAsync(options.CommunicationTimeout, cancellationToken).ConfigureAwait(false); - _logger.Verbose("Connection with server established."); + using (var combined = CancellationTokenSource.CreateLinkedTokenSource(backgroundCancellationToken, cancellationToken)) + { + _logger.Verbose($"Trying to connect with server '{options.ChannelOptions}' (Timeout={options.CommunicationTimeout})."); + await _adapter.ConnectAsync(options.CommunicationTimeout, combined.Token).ConfigureAwait(false); + _logger.Verbose("Connection with server established."); - _packetReceiverTask = Task.Run(() => TryReceivePacketsAsync(backgroundCancellationToken), backgroundCancellationToken); + _packetReceiverTask = Task.Run(() => TryReceivePacketsAsync(backgroundCancellationToken), backgroundCancellationToken); - authenticateResult = await AuthenticateAsync(adapter, options.WillMessage, cancellationToken).ConfigureAwait(false); + authenticateResult = await AuthenticateAsync(adapter, options.WillMessage, combined.Token).ConfigureAwait(false); + } _sendTracker.Restart(); + _receiveTracker.Restart(); if (Options.KeepAlivePeriod != TimeSpan.Zero) { @@ -149,7 +156,7 @@ namespace MQTTnet.Client Properties = new MqttAuthPacketProperties { // This must always be equal to the value from the CONNECT packet. So we use it here to ensure that. - AuthenticationMethod = Options.AuthenticationMethod, + AuthenticationMethod = Options.AuthenticationMethod, AuthenticationData = data.AuthenticationData, ReasonString = data.ReasonString, UserProperties = data.UserProperties @@ -161,6 +168,7 @@ namespace MQTTnet.Client { if (options == null) throw new ArgumentNullException(nameof(options)); + ThrowIfDisposed(); ThrowIfNotConnected(); var subscribePacket = _adapter.PacketFormatterAdapter.DataConverter.CreateSubscribePacket(options); @@ -174,6 +182,7 @@ namespace MQTTnet.Client { if (options == null) throw new ArgumentNullException(nameof(options)); + ThrowIfDisposed(); ThrowIfNotConnected(); var unsubscribePacket = _adapter.PacketFormatterAdapter.DataConverter.CreateUnsubscribePacket(options); @@ -189,6 +198,7 @@ namespace MQTTnet.Client MqttTopicValidator.ThrowIfInvalid(applicationMessage.Topic); + ThrowIfDisposed(); ThrowIfNotConnected(); var publishPacket = _adapter.PacketFormatterAdapter.DataConverter.CreatePublishPacket(applicationMessage); @@ -214,7 +224,7 @@ namespace MQTTnet.Client } } - public void Dispose() + private void Cleanup() { _backgroundCancellationTokenSource?.Cancel(false); _backgroundCancellationTokenSource?.Dispose(); @@ -224,6 +234,18 @@ namespace MQTTnet.Client _adapter = null; } + + protected override void Dispose(bool disposing) + { + if (disposing) + { + Cleanup(); + + DisconnectedHandler = null; + } + base.Dispose(disposing); + } + private async Task AuthenticateAsync(IMqttChannelAdapter channelAdapter, MqttApplicationMessage willApplicationMessage, CancellationToken cancellationToken) { var connectPacket = channelAdapter.PacketFormatterAdapter.DataConverter.CreateConnectPacket( @@ -235,7 +257,7 @@ namespace MQTTnet.Client if (result.ResultCode != MqttClientConnectResultCode.Success) { - throw new MqttConnectingFailedException(result.ResultCode); + throw new MqttConnectingFailedException(result); } _logger.Verbose("Authenticated MQTT connection with server established."); @@ -258,29 +280,37 @@ namespace MQTTnet.Client var clientWasConnected = IsConnected; TryInitiateDisconnect(); + IsConnected = false; try { - IsConnected = false; - if (_adapter != null) { _logger.Verbose("Disconnecting [Timeout={0}]", Options.CommunicationTimeout); await _adapter.DisconnectAsync(Options.CommunicationTimeout, CancellationToken.None).ConfigureAwait(false); } - await WaitForTaskAsync(_packetReceiverTask, sender).ConfigureAwait(false); - await WaitForTaskAsync(_keepAlivePacketsSenderTask, sender).ConfigureAwait(false); - _logger.Verbose("Disconnected from adapter."); } catch (Exception adapterException) { _logger.Warning(adapterException, "Error while disconnecting from adapter."); } + + try + { + var receiverTask = WaitForTaskAsync(_packetReceiverTask, sender); + var keepAliveTask = WaitForTaskAsync(_keepAlivePacketsSenderTask, sender); + + await Task.WhenAll(receiverTask, keepAliveTask).ConfigureAwait(false); + } + catch (Exception e) + { + _logger.Warning(e, "Error while waiting for internal tasks."); + } finally { - Dispose(); + Cleanup(); _cleanDisconnectInitiated = false; _logger.Info("Disconnected."); @@ -344,11 +374,26 @@ namespace MQTTnet.Client try { await _adapter.SendPacketAsync(requestPacket, Options.CommunicationTimeout, cancellationToken).ConfigureAwait(false); - return await packetAwaiter.WaitOneAsync(Options.CommunicationTimeout).ConfigureAwait(false); } - catch (MqttCommunicationTimedOutException) + catch (Exception e) + { + _logger.Warning(e, "Error when sending packet of type '{0}'.", typeof(TResponsePacket).Name); + packetAwaiter.Cancel(); + } + + try { - _logger.Warning(null, "Timeout while waiting for packet of type '{0}'.", typeof(TResponsePacket).Name); + var response = await packetAwaiter.WaitOneAsync(Options.CommunicationTimeout).ConfigureAwait(false); + _receiveTracker.Restart(); + return response; + } + catch (Exception exception) + { + if (exception is MqttCommunicationTimedOutException) + { + _logger.Warning(null, "Timeout while waiting for packet of type '{0}'.", typeof(TResponsePacket).Name); + } + throw; } } @@ -369,14 +414,14 @@ namespace MQTTnet.Client keepAliveSendInterval = Options.KeepAliveSendInterval.Value; } - var waitTime = keepAliveSendInterval - _sendTracker.Elapsed; - if (waitTime <= TimeSpan.Zero) + var waitTimeSend = keepAliveSendInterval - _sendTracker.Elapsed; + var waitTimeReceive = keepAliveSendInterval - _receiveTracker.Elapsed; + if (waitTimeSend <= TimeSpan.Zero || waitTimeReceive <= TimeSpan.Zero) { await SendAndReceiveAsync(new MqttPingReqPacket(), cancellationToken).ConfigureAwait(false); - waitTime = keepAliveSendInterval; } - await Task.Delay(waitTime, cancellationToken).ConfigureAwait(false); + await Task.Delay(keepAliveSendInterval, cancellationToken).ConfigureAwait(false); } } catch (Exception exception) @@ -391,11 +436,11 @@ namespace MQTTnet.Client } else if (exception is MqttCommunicationException) { - _logger.Warning(exception, "MQTT communication exception while sending/receiving keep alive packets."); + _logger.Warning(exception, "Communication error while sending/receiving keep alive packets."); } else { - _logger.Error(exception, "Unhandled exception while sending/receiving keep alive packets."); + _logger.Error(exception, "Error exception while sending/receiving keep alive packets."); } if (!DisconnectIsPending()) @@ -449,11 +494,11 @@ namespace MQTTnet.Client } else if (exception is MqttCommunicationException) { - _logger.Warning(exception, "MQTT communication exception while receiving packets."); + _logger.Warning(exception, "Communication error while receiving packets."); } else { - _logger.Error(exception, "Unhandled exception while receiving packets."); + _logger.Error(exception, "Error while receiving packets."); } _packetDispatcher.Dispatch(exception); @@ -473,6 +518,8 @@ namespace MQTTnet.Client { try { + _receiveTracker.Restart(); + if (packet is MqttPublishPacket publishPacket) { await TryProcessReceivedPublishPacketAsync(publishPacket, cancellationToken).ConfigureAwait(false); @@ -521,11 +568,11 @@ namespace MQTTnet.Client } else if (exception is MqttCommunicationException) { - _logger.Warning(exception, "MQTT communication exception while receiving packets."); + _logger.Warning(exception, "Communication error while receiving packets."); } else { - _logger.Error(exception, "Unhandled exception while receiving packets."); + _logger.Error(exception, "Error while receiving packets."); } _packetDispatcher.Dispatch(exception); @@ -567,7 +614,7 @@ namespace MQTTnet.Client }; await SendAsync(pubRecPacket, cancellationToken).ConfigureAwait(false); - } + } } else { @@ -576,7 +623,7 @@ namespace MQTTnet.Client } catch (Exception exception) { - _logger.Error(exception, "Unhandled exception while handling application message."); + _logger.Error(exception, "Error while handling application message."); } } @@ -626,15 +673,25 @@ namespace MQTTnet.Client return true; } - private static async Task WaitForTaskAsync(Task task, Task sender) + private async Task WaitForTaskAsync(Task task, Task sender) { - if (task == sender || task == null) + if (task == null) { return; } - if (task.IsCanceled || task.IsCompleted || task.IsFaulted) + if (task == sender) { + // Return here to avoid deadlocks, but first any eventual exception in the task + // must be handled to avoid not getting an unhandled task exception + if (!task.IsFaulted) + { + return; + } + + // By accessing the Exception property the exception is considered handled and will + // not result in an unhandled task exception later by the finalizer + _logger.Warning(task.Exception, "Error while waiting for background task."); return; } @@ -652,4 +709,4 @@ namespace MQTTnet.Client return Interlocked.CompareExchange(ref _disconnectGate, 1, 0) != 0; } } -} \ No newline at end of file +} diff --git a/Source/MQTTnet/Client/Options/MqttClientOptionsBuilder.cs b/Source/MQTTnet/Client/Options/MqttClientOptionsBuilder.cs index 65a1ec9..4fd0ccf 100644 --- a/Source/MQTTnet/Client/Options/MqttClientOptionsBuilder.cs +++ b/Source/MQTTnet/Client/Options/MqttClientOptionsBuilder.cs @@ -256,7 +256,11 @@ namespace MQTTnet.Client.Options UseTls = true, SslProtocol = _tlsParameters.SslProtocol, AllowUntrustedCertificates = _tlsParameters.AllowUntrustedCertificates, +#if WINDOWS_UWP Certificates = _tlsParameters.Certificates?.Select(c => c.ToArray()).ToList(), +#else + Certificates = _tlsParameters.Certificates?.ToList(), +#endif CertificateValidationCallback = _tlsParameters.CertificateValidationCallback, IgnoreCertificateChainErrors = _tlsParameters.IgnoreCertificateChainErrors, IgnoreCertificateRevocationErrors = _tlsParameters.IgnoreCertificateRevocationErrors diff --git a/Source/MQTTnet/Client/Options/MqttClientOptionsBuilderTlsParameters.cs b/Source/MQTTnet/Client/Options/MqttClientOptionsBuilderTlsParameters.cs index ea36baa..d1854ff 100644 --- a/Source/MQTTnet/Client/Options/MqttClientOptionsBuilderTlsParameters.cs +++ b/Source/MQTTnet/Client/Options/MqttClientOptionsBuilderTlsParameters.cs @@ -18,7 +18,12 @@ namespace MQTTnet.Client.Options public SslProtocols SslProtocol { get; set; } = SslProtocols.Tls12; +#if WINDOWS_UWP public IEnumerable> Certificates { get; set; } +#else + public IEnumerable Certificates { get; set; } +#endif + public bool AllowUntrustedCertificates { get; set; } diff --git a/Source/MQTTnet/Client/Options/MqttClientTlsOptions.cs b/Source/MQTTnet/Client/Options/MqttClientTlsOptions.cs index db4077d..0d1a3a5 100644 --- a/Source/MQTTnet/Client/Options/MqttClientTlsOptions.cs +++ b/Source/MQTTnet/Client/Options/MqttClientTlsOptions.cs @@ -15,8 +15,11 @@ namespace MQTTnet.Client.Options public bool IgnoreCertificateChainErrors { get; set; } public bool AllowUntrustedCertificates { get; set; } - +#if WINDOWS_UWP public List Certificates { get; set; } +#else + public List Certificates { get; set; } +#endif public SslProtocols SslProtocol { get; set; } = SslProtocols.Tls12; diff --git a/Source/MQTTnet/Client/Unsubscribing/MqttClientUnsubscribeOptionsBuilder.cs b/Source/MQTTnet/Client/Unsubscribing/MqttClientUnsubscribeOptionsBuilder.cs new file mode 100644 index 0000000..96c178f --- /dev/null +++ b/Source/MQTTnet/Client/Unsubscribing/MqttClientUnsubscribeOptionsBuilder.cs @@ -0,0 +1,60 @@ +using MQTTnet.Packets; +using System; +using System.Collections.Generic; +using System.Text; + +namespace MQTTnet.Client.Unsubscribing +{ + public class MqttClientUnsubscribeOptionsBuilder + { + private readonly MqttClientUnsubscribeOptions _unsubscribeOptions = new MqttClientUnsubscribeOptions(); + + public MqttClientUnsubscribeOptionsBuilder WithUserProperty(string name, string value) + { + if (name is null) throw new ArgumentNullException(nameof(name)); + if (value is null) throw new ArgumentNullException(nameof(value)); + + return WithUserProperty(new MqttUserProperty(name, value)); + } + + public MqttClientUnsubscribeOptionsBuilder WithUserProperty(MqttUserProperty userProperty) + { + if (userProperty is null) throw new ArgumentNullException(nameof(userProperty)); + + if (_unsubscribeOptions.UserProperties is null) + { + _unsubscribeOptions.UserProperties = new List(); + } + + _unsubscribeOptions.UserProperties.Add(userProperty); + + return this; + } + + public MqttClientUnsubscribeOptionsBuilder WithTopicFilter(string topic) + { + if (topic is null) throw new ArgumentNullException(nameof(topic)); + + if (_unsubscribeOptions.TopicFilters is null) + { + _unsubscribeOptions.TopicFilters = new List(); + } + + _unsubscribeOptions.TopicFilters.Add(topic); + + return this; + } + + public MqttClientUnsubscribeOptionsBuilder WithTopicFilter(TopicFilter topicFilter) + { + if (topicFilter is null) throw new ArgumentNullException(nameof(topicFilter)); + + return WithTopic(topicFilter.Topic); + } + + public MqttClientUnsubscribeOptions Build() + { + return _unsubscribeOptions; + } + } +} diff --git a/Source/MQTTnet/Exceptions/MqttConfigurationException.cs b/Source/MQTTnet/Exceptions/MqttConfigurationException.cs new file mode 100644 index 0000000..4d10faf --- /dev/null +++ b/Source/MQTTnet/Exceptions/MqttConfigurationException.cs @@ -0,0 +1,21 @@ +using System; + +namespace MQTTnet.Exceptions +{ + public class MqttConfigurationException : Exception + { + protected MqttConfigurationException() + { + } + + public MqttConfigurationException(Exception innerException) + : base(innerException.Message, innerException) + { + } + + public MqttConfigurationException(string message) + : base(message) + { + } + } +} diff --git a/Source/MQTTnet/Formatter/V5/MqttV500DataConverter.cs b/Source/MQTTnet/Formatter/V5/MqttV500DataConverter.cs index 42d3241..886b937 100644 --- a/Source/MQTTnet/Formatter/V5/MqttV500DataConverter.cs +++ b/Source/MQTTnet/Formatter/V5/MqttV500DataConverter.cs @@ -139,7 +139,7 @@ namespace MQTTnet.Formatter.V5 ReasonCode = connectionValidatorContext.ReasonCode, Properties = new MqttConnAckPacketProperties { - UserProperties = connectionValidatorContext.UserProperties, + UserProperties = connectionValidatorContext.ResponseUserProperties, AuthenticationMethod = connectionValidatorContext.AuthenticationMethod, AuthenticationData = connectionValidatorContext.ResponseAuthenticationData, AssignedClientIdentifier = connectionValidatorContext.AssignedClientIdentifier, diff --git a/Source/MQTTnet/Implementations/MqttTcpChannel.cs b/Source/MQTTnet/Implementations/MqttTcpChannel.cs index d7943ad..9b2ba56 100644 --- a/Source/MQTTnet/Implementations/MqttTcpChannel.cs +++ b/Source/MQTTnet/Implementations/MqttTcpChannel.cs @@ -10,10 +10,11 @@ using System.Runtime.ExceptionServices; using System.Threading; using MQTTnet.Channel; using MQTTnet.Client.Options; +using MQTTnet.Internal; namespace MQTTnet.Implementations { - public class MqttTcpChannel : IMqttChannel + public class MqttTcpChannel : Disposable, IMqttChannel { private readonly IMqttClientOptions _clientOptions; private readonly MqttClientTcpOptions _options; @@ -72,11 +73,7 @@ namespace MQTTnet.Implementations // Workaround for: workaround for https://github.com/dotnet/corefx/issues/24430 using (cancellationToken.Register(() => socket.Dispose())) { -#if NET452 || NET461 - await Task.Factory.FromAsync(socket.BeginConnect, socket.EndConnect, _options.Server, _options.GetPort(), null).ConfigureAwait(false); -#else - await socket.ConnectAsync(_options.Server, _options.GetPort()).ConfigureAwait(false); -#endif + await PlatformAbstractionLayer.ConnectAsync(socket, _options.Server, _options.GetPort()).ConfigureAwait(false); } var networkStream = new NetworkStream(socket, true); @@ -98,7 +95,7 @@ namespace MQTTnet.Implementations public Task DisconnectAsync(CancellationToken cancellationToken) { - Dispose(); + Cleanup(); return Task.FromResult(0); } @@ -117,6 +114,10 @@ namespace MQTTnet.Implementations return await _stream.ReadAsync(buffer, offset, count, cancellationToken).ConfigureAwait(false); } } + catch (ObjectDisposedException) + { + return 0; + } catch (IOException exception) { if (exception.InnerException is SocketException socketException) @@ -143,6 +144,10 @@ namespace MQTTnet.Implementations await _stream.WriteAsync(buffer, offset, count, cancellationToken).ConfigureAwait(false); } } + catch (ObjectDisposedException) + { + return; + } catch (IOException exception) { if (exception.InnerException is SocketException socketException) @@ -154,7 +159,7 @@ namespace MQTTnet.Implementations } } - public void Dispose() + private void Cleanup() { // When the stream is disposed it will also close the socket and this will also dispose it. // So there is no need to dispose the socket again. @@ -173,6 +178,15 @@ namespace MQTTnet.Implementations _stream = null; } + protected override void Dispose(bool disposing) + { + if (disposing) + { + Cleanup(); + } + base.Dispose(disposing); + } + private bool InternalUserCertificateValidationCallback(object sender, X509Certificate x509Certificate, X509Chain chain, SslPolicyErrors sslPolicyErrors) { if (_options.TlsOptions.CertificateValidationCallback != null) @@ -214,7 +228,7 @@ namespace MQTTnet.Implementations foreach (var certificate in _options.TlsOptions.Certificates) { - certificates.Add(new X509Certificate2(certificate)); + certificates.Add(certificate); } return certificates; diff --git a/Source/MQTTnet/Implementations/MqttTcpServerAdapter.cs b/Source/MQTTnet/Implementations/MqttTcpServerAdapter.cs index d7f4e6f..501c4da 100644 --- a/Source/MQTTnet/Implementations/MqttTcpServerAdapter.cs +++ b/Source/MQTTnet/Implementations/MqttTcpServerAdapter.cs @@ -8,11 +8,12 @@ using System.Threading; using System.Threading.Tasks; using MQTTnet.Adapter; using MQTTnet.Diagnostics; +using MQTTnet.Internal; using MQTTnet.Server; namespace MQTTnet.Implementations { - public class MqttTcpServerAdapter : IMqttServerAdapter + public class MqttTcpServerAdapter : Disposable, IMqttServerAdapter { private readonly List _listeners = new List(); private readonly IMqttNetChildLogger _logger; @@ -72,11 +73,11 @@ namespace MQTTnet.Implementations public Task StopAsync() { - Dispose(); + Cleanup(); return Task.FromResult(0); } - public void Dispose() + private void Cleanup() { _cancellationTokenSource?.Cancel(false); _cancellationTokenSource?.Dispose(); @@ -90,6 +91,15 @@ namespace MQTTnet.Implementations _listeners.Clear(); } + protected override void Dispose(bool disposing) + { + if (disposing) + { + Cleanup(); + } + base.Dispose(disposing); + } + private void RegisterListeners(MqttServerTcpEndpointBaseOptions options, X509Certificate2 tlsCertificate, CancellationToken cancellationToken) { if (!options.BoundInterNetworkAddress.Equals(IPAddress.None)) diff --git a/Source/MQTTnet/Implementations/MqttTcpServerListener.cs b/Source/MQTTnet/Implementations/MqttTcpServerListener.cs index d57888e..f2f439e 100644 --- a/Source/MQTTnet/Implementations/MqttTcpServerListener.cs +++ b/Source/MQTTnet/Implementations/MqttTcpServerListener.cs @@ -107,12 +107,7 @@ namespace MQTTnet.Implementations { try { -#if NET452 || NET461 - var clientSocket = await Task.Factory.FromAsync(_socket.BeginAccept, _socket.EndAccept, null).ConfigureAwait(false); -#else - var clientSocket = await _socket.AcceptAsync().ConfigureAwait(false); -#endif - + var clientSocket = await PlatformAbstractionLayer.AcceptAsync(_socket).ConfigureAwait(false); if (clientSocket == null) { continue; diff --git a/Source/MQTTnet/Implementations/MqttWebSocketChannel.cs b/Source/MQTTnet/Implementations/MqttWebSocketChannel.cs index 38e4342..c159b91 100644 --- a/Source/MQTTnet/Implementations/MqttWebSocketChannel.cs +++ b/Source/MQTTnet/Implementations/MqttWebSocketChannel.cs @@ -6,10 +6,11 @@ using System.Threading; using System.Threading.Tasks; using MQTTnet.Channel; using MQTTnet.Client.Options; +using MQTTnet.Internal; namespace MQTTnet.Implementations { - public class MqttWebSocketChannel : IMqttChannel + public class MqttWebSocketChannel : Disposable, IMqttChannel { private readonly MqttClientWebSocketOptions _options; @@ -84,7 +85,12 @@ namespace MQTTnet.Implementations clientWebSocket.Options.ClientCertificates = new X509CertificateCollection(); foreach (var certificate in _options.TlsOptions.Certificates) { +#if WINDOWS_UWP clientWebSocket.Options.ClientCertificates.Add(new X509Certificate(certificate)); +#else + clientWebSocket.Options.ClientCertificates.Add(certificate); +#endif + } } @@ -106,7 +112,7 @@ namespace MQTTnet.Implementations await _webSocket.CloseAsync(WebSocketCloseStatus.NormalClosure, string.Empty, cancellationToken).ConfigureAwait(false); } - Dispose(); + Cleanup(); } public async Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) @@ -136,7 +142,16 @@ namespace MQTTnet.Implementations } } - public void Dispose() + protected override void Dispose(bool disposing) + { + if (disposing) + { + Cleanup(); + } + base.Dispose(disposing); + } + + private void Cleanup() { _sendLock?.Dispose(); _sendLock = null; diff --git a/Source/MQTTnet/Implementations/PlatformAbstractionLayer.cs b/Source/MQTTnet/Implementations/PlatformAbstractionLayer.cs new file mode 100644 index 0000000..ee9057a --- /dev/null +++ b/Source/MQTTnet/Implementations/PlatformAbstractionLayer.cs @@ -0,0 +1,104 @@ +using System; +using System.Net; +using System.Net.Sockets; +using System.Threading.Tasks; + +namespace MQTTnet.Implementations +{ + public static class PlatformAbstractionLayer + { + public static async Task AcceptAsync(Socket socket) + { +#if NET452 || NET461 + try + { + return await Task.Factory.FromAsync(socket.BeginAccept, socket.EndAccept, null).ConfigureAwait(false); + } + catch (ObjectDisposedException) + { + return null; + } +#else + return await socket.AcceptAsync().ConfigureAwait(false); +#endif + } + + + public static Task ConnectAsync(Socket socket, IPAddress ip, int port) + { +#if NET452 || NET461 + return Task.Factory.FromAsync(socket.BeginConnect, socket.EndConnect, ip, port, null); +#else + return socket.ConnectAsync(ip, port); +#endif + } + + public static Task ConnectAsync(Socket socket, string host, int port) + { +#if NET452 || NET461 + return Task.Factory.FromAsync(socket.BeginConnect, socket.EndConnect, host, port, null); +#else + return socket.ConnectAsync(host, port); +#endif + } + +#if NET452 || NET461 + public class SocketWrapper + { + private readonly Socket _socket; + private readonly ArraySegment _buffer; + private readonly SocketFlags _socketFlags; + + public SocketWrapper(Socket socket, ArraySegment buffer, SocketFlags socketFlags) + { + _socket = socket; + _buffer = buffer; + _socketFlags = socketFlags; + } + + public static IAsyncResult BeginSend(AsyncCallback callback, object state) + { + var real = (SocketWrapper)state; + return real._socket.BeginSend(real._buffer.Array, real._buffer.Offset, real._buffer.Count, real._socketFlags, callback, state); + } + + public static IAsyncResult BeginReceive(AsyncCallback callback, object state) + { + var real = (SocketWrapper)state; + return real._socket.BeginReceive(real._buffer.Array, real._buffer.Offset, real._buffer.Count, real._socketFlags, callback, state); + } + } +#endif + + public static Task SendAsync(Socket socket, ArraySegment buffer, SocketFlags socketFlags) + { +#if NET452 || NET461 + return Task.Factory.FromAsync(SocketWrapper.BeginSend, socket.EndSend, new SocketWrapper(socket, buffer, socketFlags)); +#else + return socket.SendAsync(buffer, socketFlags); +#endif + } + + public static Task ReceiveAsync(Socket socket, ArraySegment buffer, SocketFlags socketFlags) + { +#if NET452 || NET461 + return Task.Factory.FromAsync(SocketWrapper.BeginReceive, socket.EndReceive, new SocketWrapper(socket, buffer, socketFlags)); +#else + return socket.ReceiveAsync(buffer, socketFlags); +#endif + } + + public static Task CompletedTask + { + get + { +#if NET452 + return Task.FromResult(0); +#else + return Task.CompletedTask; +#endif + } + } + + } +} diff --git a/Source/MQTTnet/Internal/AsyncAutoResetEvent.cs b/Source/MQTTnet/Internal/AsyncAutoResetEvent.cs deleted file mode 100644 index cd62f07..0000000 --- a/Source/MQTTnet/Internal/AsyncAutoResetEvent.cs +++ /dev/null @@ -1,131 +0,0 @@ -using System; -using System.Collections.Generic; -using System.Threading; -using System.Threading.Tasks; - -namespace MQTTnet.Internal -{ - // Inspired from Stephen Toub (https://blogs.msdn.microsoft.com/pfxteam/2012/02/11/building-async-coordination-primitives-part-2-asyncautoresetevent/) and Chris Gillum (https://stackoverflow.com/a/43012490) - public class AsyncAutoResetEvent - { - private readonly LinkedList> _waiters = new LinkedList>(); - - private bool _isSignaled; - - public AsyncAutoResetEvent() - : this(false) - { - } - - public AsyncAutoResetEvent(bool signaled) - { - _isSignaled = signaled; - } - - public int WaitersCount - { - get - { - lock (_waiters) - { - return _waiters.Count; - } - } - } - - public Task WaitOneAsync() - { - return WaitOneAsync(CancellationToken.None); - } - - public Task WaitOneAsync(TimeSpan timeout) - { - return WaitOneAsync(timeout, CancellationToken.None); - } - - public Task WaitOneAsync(CancellationToken cancellationToken) - { - return WaitOneAsync(Timeout.InfiniteTimeSpan, cancellationToken); - } - - public async Task WaitOneAsync(TimeSpan timeout, CancellationToken cancellationToken) - { - cancellationToken.ThrowIfCancellationRequested(); - - TaskCompletionSource tcs; - - lock (_waiters) - { - if (_isSignaled) - { - _isSignaled = false; - return true; - } - - if (timeout == TimeSpan.Zero) - { - return _isSignaled; - } - - tcs = new TaskCompletionSource(); - _waiters.AddLast(tcs); - } - - Task winner; - if (timeout == Timeout.InfiniteTimeSpan) - { - using (cancellationToken.Register(() => { tcs.TrySetCanceled(); })) - { - await tcs.Task.ConfigureAwait(false); - winner = tcs.Task; - } - } - else - { - winner = await Task.WhenAny(tcs.Task, Task.Delay(timeout, cancellationToken)).ConfigureAwait(false); - } - - var taskWasSignaled = winner == tcs.Task; - if (taskWasSignaled) - { - return true; - } - - // We timed-out; remove our reference to the task. - // This is an O(n) operation since waiters is a LinkedList. - lock (_waiters) - { - _waiters.Remove(tcs); - - if (winner.Status == TaskStatus.Canceled) - { - throw new OperationCanceledException(cancellationToken); - } - - throw new TimeoutException(); - } - } - - public void Set() - { - TaskCompletionSource toRelease = null; - - lock (_waiters) - { - if (_waiters.Count > 0) - { - // Signal the first task in the waiters list. - toRelease = _waiters.First.Value; - _waiters.RemoveFirst(); - } - else if (!_isSignaled) - { - // No tasks are pending - _isSignaled = true; - } - } - - toRelease?.TrySetResult(true); - } - } -} diff --git a/Source/MQTTnet/Internal/BlockingQueue.cs b/Source/MQTTnet/Internal/BlockingQueue.cs index 485f644..2fa21be 100644 --- a/Source/MQTTnet/Internal/BlockingQueue.cs +++ b/Source/MQTTnet/Internal/BlockingQueue.cs @@ -4,11 +4,11 @@ using System.Threading; namespace MQTTnet.Internal { - public class BlockingQueue + public class BlockingQueue : Disposable { private readonly object _syncRoot = new object(); private readonly LinkedList _items = new LinkedList(); - private readonly ManualResetEvent _gate = new ManualResetEvent(false); + private readonly ManualResetEventSlim _gate = new ManualResetEventSlim(false); public int Count { @@ -32,7 +32,7 @@ namespace MQTTnet.Internal } } - public TItem Dequeue() + public TItem Dequeue(CancellationToken cancellationToken = default(CancellationToken)) { while (true) { @@ -52,11 +52,11 @@ namespace MQTTnet.Internal } } - _gate.WaitOne(); + _gate.Wait(cancellationToken); } } - public TItem PeekAndWait() + public TItem PeekAndWait(CancellationToken cancellationToken = default(CancellationToken)) { while (true) { @@ -73,7 +73,7 @@ namespace MQTTnet.Internal } } - _gate.WaitOne(); + _gate.Wait(cancellationToken); } } @@ -108,5 +108,14 @@ namespace MQTTnet.Internal _items.Clear(); } } + + protected override void Dispose(bool disposing) + { + if (disposing) + { + _gate.Dispose(); + } + base.Dispose(disposing); + } } } diff --git a/Source/MQTTnet/Internal/Disposable.cs b/Source/MQTTnet/Internal/Disposable.cs new file mode 100644 index 0000000..2ce3423 --- /dev/null +++ b/Source/MQTTnet/Internal/Disposable.cs @@ -0,0 +1,57 @@ +using System; + +namespace MQTTnet.Internal +{ + public class Disposable : IDisposable + { + protected bool IsDisposed => _isDisposed; + + protected void ThrowIfDisposed() + { + if (_isDisposed) + { + throw new ObjectDisposedException(GetType().Name); + } + } + + + #region IDisposable Support + + private bool _isDisposed = false; // To detect redundant calls + + protected virtual void Dispose(bool disposing) + { + if (disposing) + { + // TODO: dispose managed state (managed objects). + } + + // TODO: free unmanaged resources (unmanaged objects) and override a finalizer below. + // TODO: set large fields to null. + } + + // TODO: override a finalizer only if Dispose(bool disposing) above has code to free unmanaged resources. + // ~Disposable() + // { + // // Do not change this code. Put cleanup code in Dispose(bool disposing) above. + // Dispose(false); + // } + + // This code added to correctly implement the disposable pattern. + public void Dispose() + { + if (_isDisposed) + { + return; + } + + _isDisposed = true; + + // Do not change this code. Put cleanup code in Dispose(bool disposing) above. + Dispose(true); + // TODO: uncomment the following line if the finalizer is overridden above. + // GC.SuppressFinalize(this); + } + #endregion + } +} diff --git a/Source/MQTTnet/MQTTnet.csproj b/Source/MQTTnet/MQTTnet.csproj index 61ca517..28e3f3d 100644 --- a/Source/MQTTnet/MQTTnet.csproj +++ b/Source/MQTTnet/MQTTnet.csproj @@ -62,5 +62,9 @@ + + + + \ No newline at end of file diff --git a/Source/MQTTnet/PacketDispatcher/MqttPacketAwaiter.cs b/Source/MQTTnet/PacketDispatcher/MqttPacketAwaiter.cs index 19df6d4..b172290 100644 --- a/Source/MQTTnet/PacketDispatcher/MqttPacketAwaiter.cs +++ b/Source/MQTTnet/PacketDispatcher/MqttPacketAwaiter.cs @@ -2,13 +2,14 @@ using System.Threading; using System.Threading.Tasks; using MQTTnet.Exceptions; +using MQTTnet.Internal; using MQTTnet.Packets; namespace MQTTnet.PacketDispatcher { - public sealed class MqttPacketAwaiter : IMqttPacketAwaiter where TPacket : MqttBasePacket + public sealed class MqttPacketAwaiter : Disposable, IMqttPacketAwaiter where TPacket : MqttBasePacket { - private readonly TaskCompletionSource _taskCompletionSource = new TaskCompletionSource(); + private readonly TaskCompletionSource _taskCompletionSource; private readonly ushort? _packetIdentifier; private readonly MqttPacketDispatcher _owningPacketDispatcher; @@ -16,13 +17,18 @@ namespace MQTTnet.PacketDispatcher { _packetIdentifier = packetIdentifier; _owningPacketDispatcher = owningPacketDispatcher ?? throw new ArgumentNullException(nameof(owningPacketDispatcher)); +#if NET452 + _taskCompletionSource = new TaskCompletionSource(); +#else + _taskCompletionSource = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); +#endif } public async Task WaitOneAsync(TimeSpan timeout) { using (var timeoutToken = new CancellationTokenSource(timeout)) { - timeoutToken.Token.Register(() => _taskCompletionSource.TrySetException(new MqttCommunicationTimedOutException())); + timeoutToken.Token.Register(() => Fail(new MqttCommunicationTimedOutException())); var packet = await _taskCompletionSource.Task.ConfigureAwait(false); return (TPacket)packet; @@ -32,29 +38,56 @@ namespace MQTTnet.PacketDispatcher public void Complete(MqttBasePacket packet) { if (packet == null) throw new ArgumentNullException(nameof(packet)); - + + +#if NET452 // To prevent deadlocks it is required to call the _TrySetResult_ method // from a new thread because the awaiting code will not(!) be executed in // a new thread automatically (due to await). Furthermore _this_ thread will // do it. But _this_ thread is also reading incoming packets -> deadlock. + // NET452 does not support RunContinuationsAsynchronously Task.Run(() => _taskCompletionSource.TrySetResult(packet)); +#else + _taskCompletionSource.TrySetResult(packet); +#endif } public void Fail(Exception exception) { if (exception == null) throw new ArgumentNullException(nameof(exception)); - +#if NET452 + // To prevent deadlocks it is required to call the _TrySetResult_ method + // from a new thread because the awaiting code will not(!) be executed in + // a new thread automatically (due to await). Furthermore _this_ thread will + // do it. But _this_ thread is also reading incoming packets -> deadlock. + // NET452 does not support RunContinuationsAsynchronously Task.Run(() => _taskCompletionSource.TrySetException(exception)); +#else + _taskCompletionSource.TrySetException(exception); +#endif } public void Cancel() { +#if NET452 + // To prevent deadlocks it is required to call the _TrySetResult_ method + // from a new thread because the awaiting code will not(!) be executed in + // a new thread automatically (due to await). Furthermore _this_ thread will + // do it. But _this_ thread is also reading incoming packets -> deadlock. + // NET452 does not support RunContinuationsAsynchronously Task.Run(() => _taskCompletionSource.TrySetCanceled()); +#else + _taskCompletionSource.TrySetCanceled(); +#endif } - public void Dispose() + protected override void Dispose(bool disposing) { - _owningPacketDispatcher.RemovePacketAwaiter(_packetIdentifier); + if (disposing) + { + _owningPacketDispatcher.RemovePacketAwaiter(_packetIdentifier); + } + base.Dispose(disposing); } } } \ No newline at end of file diff --git a/Source/MQTTnet/Server/IMqttRetainedMessagesManager.cs b/Source/MQTTnet/Server/IMqttRetainedMessagesManager.cs new file mode 100644 index 0000000..6ffdd2b --- /dev/null +++ b/Source/MQTTnet/Server/IMqttRetainedMessagesManager.cs @@ -0,0 +1,21 @@ +using System.Collections.Generic; +using System.Threading.Tasks; +using MQTTnet.Diagnostics; + +namespace MQTTnet.Server +{ + public interface IMqttRetainedMessagesManager + { + Task Start(IMqttServerOptions options, IMqttNetChildLogger logger); + + Task LoadMessagesAsync(); + + Task ClearMessagesAsync(); + + Task HandleMessageAsync(string clientId, MqttApplicationMessage applicationMessage); + + Task> GetMessagesAsync(); + + Task> GetSubscribedMessagesAsync(ICollection topicFilters); + } +} diff --git a/Source/MQTTnet/Server/IMqttServerOptions.cs b/Source/MQTTnet/Server/IMqttServerOptions.cs index 7c5fde4..7df6f54 100644 --- a/Source/MQTTnet/Server/IMqttServerOptions.cs +++ b/Source/MQTTnet/Server/IMqttServerOptions.cs @@ -15,14 +15,15 @@ namespace MQTTnet.Server IMqttServerConnectionValidator ConnectionValidator { get; } IMqttServerSubscriptionInterceptor SubscriptionInterceptor { get; } + IMqttServerUnsubscriptionInterceptor UnsubscriptionInterceptor { get; } IMqttServerApplicationMessageInterceptor ApplicationMessageInterceptor { get; } IMqttServerClientMessageQueueInterceptor ClientMessageQueueInterceptor { get; } MqttServerTcpEndpointOptions DefaultEndpointOptions { get; } MqttServerTlsTcpEndpointOptions TlsEndpointOptions { get; } - IMqttServerStorage Storage { get; } - + IMqttServerStorage Storage { get; } + IMqttRetainedMessagesManager RetainedMessagesManager { get; } } } \ No newline at end of file diff --git a/Source/MQTTnet/Server/IMqttServerUnsubscriptionInterceptor.cs b/Source/MQTTnet/Server/IMqttServerUnsubscriptionInterceptor.cs new file mode 100644 index 0000000..9669383 --- /dev/null +++ b/Source/MQTTnet/Server/IMqttServerUnsubscriptionInterceptor.cs @@ -0,0 +1,9 @@ +using System.Threading.Tasks; + +namespace MQTTnet.Server +{ + public interface IMqttServerUnsubscriptionInterceptor + { + Task InterceptUnsubscriptionAsync(MqttUnsubscriptionInterceptorContext context); + } +} diff --git a/Source/MQTTnet/Server/MqttClientConnection.cs b/Source/MQTTnet/Server/MqttClientConnection.cs index e71d1a8..c9e2553 100644 --- a/Source/MQTTnet/Server/MqttClientConnection.cs +++ b/Source/MQTTnet/Server/MqttClientConnection.cs @@ -21,7 +21,7 @@ namespace MQTTnet.Server private readonly MqttPacketDispatcher _packetDispatcher = new MqttPacketDispatcher(); private readonly CancellationTokenSource _cancellationToken = new CancellationTokenSource(); - private readonly MqttRetainedMessagesManager _retainedMessagesManager; + private readonly IMqttRetainedMessagesManager _retainedMessagesManager; private readonly MqttClientKeepAliveMonitor _keepAliveMonitor; private readonly MqttClientSessionsManager _sessionsManager; @@ -36,7 +36,7 @@ namespace MQTTnet.Server private Task _packageReceiverTask; private DateTime _lastPacketReceivedTimestamp; private DateTime _lastNonKeepAlivePacketReceivedTimestamp; - + private long _receivedPacketsCount; private long _sentPacketsCount = 1; // Start with 1 because the CONNECT packet is not counted anywhere. private long _receivedApplicationMessagesCount; @@ -48,14 +48,14 @@ namespace MQTTnet.Server MqttClientSession session, IMqttServerOptions serverOptions, MqttClientSessionsManager sessionsManager, - MqttRetainedMessagesManager retainedMessagesManager, + IMqttRetainedMessagesManager retainedMessagesManager, IMqttNetChildLogger logger) { Session = session ?? throw new ArgumentNullException(nameof(session)); _serverOptions = serverOptions ?? throw new ArgumentNullException(nameof(serverOptions)); _sessionsManager = sessionsManager ?? throw new ArgumentNullException(nameof(sessionsManager)); _retainedMessagesManager = retainedMessagesManager ?? throw new ArgumentNullException(nameof(retainedMessagesManager)); - + _channelAdapter = channelAdapter ?? throw new ArgumentNullException(nameof(channelAdapter)); _dataConverter = _channelAdapter.PacketFormatterAdapter.DataConverter; _endpoint = _channelAdapter.Endpoint; @@ -76,7 +76,7 @@ namespace MQTTnet.Server public string ClientId => ConnectPacket.ClientId; public MqttClientSession Session { get; } - + public async Task StopAsync() { StopInternal(); @@ -112,25 +112,25 @@ namespace MQTTnet.Server status.BytesSent = _channelAdapter.BytesSent; status.BytesReceived = _channelAdapter.BytesReceived; } - + public void Dispose() { _cancellationToken.Dispose(); } - public Task RunAsync() + public Task RunAsync(MqttConnectionValidatorContext connectionValidatorContext) { - _packageReceiverTask = RunInternalAsync(); + _packageReceiverTask = RunInternalAsync(connectionValidatorContext); return _packageReceiverTask; } - private async Task RunInternalAsync() + private async Task RunInternalAsync(MqttConnectionValidatorContext connectionValidatorContext) { var disconnectType = MqttClientDisconnectType.NotClean; try { _logger.Info("Client '{0}': Session started.", ClientId); - + _channelAdapter.ReadingPacketStartedCallback = OnAdapterReadingPacketStarted; _channelAdapter.ReadingPacketCompletedCallback = OnAdapterReadingPacketCompleted; @@ -142,12 +142,8 @@ namespace MQTTnet.Server _keepAliveMonitor.Start(ConnectPacket.KeepAlivePeriod, _cancellationToken.Token); await SendAsync( - new MqttConnAckPacket - { - ReturnCode = MqttConnectReturnCode.ConnectionAccepted, - ReasonCode = MqttConnectReasonCode.Success, - IsSessionPresent = !Session.IsCleanSession - }).ConfigureAwait(false); + _channelAdapter.PacketFormatterAdapter.DataConverter.CreateConnAckPacket(connectionValidatorContext) + ).ConfigureAwait(false); Session.IsCleanSession = false; @@ -248,7 +244,7 @@ namespace MQTTnet.Server _channelAdapter.ReadingPacketCompletedCallback = null; _logger.Info("Client '{0}': Session stopped.", ClientId); - + _packageReceiverTask = null; } diff --git a/Source/MQTTnet/Server/MqttClientSession.cs b/Source/MQTTnet/Server/MqttClientSession.cs index d165001..d097b9f 100644 --- a/Source/MQTTnet/Server/MqttClientSession.cs +++ b/Source/MQTTnet/Server/MqttClientSession.cs @@ -52,7 +52,7 @@ namespace MQTTnet.Server ApplicationMessagesQueue.Enqueue(applicationMessage, senderClientId, checkSubscriptionsResult.QualityOfServiceLevel, isRetainedApplicationMessage); } - public async Task SubscribeAsync(ICollection topicFilters, MqttRetainedMessagesManager retainedMessagesManager) + public async Task SubscribeAsync(ICollection topicFilters, IMqttRetainedMessagesManager retainedMessagesManager) { await SubscriptionsManager.SubscribeAsync(topicFilters).ConfigureAwait(false); diff --git a/Source/MQTTnet/Server/MqttClientSessionApplicationMessagesQueue.cs b/Source/MQTTnet/Server/MqttClientSessionApplicationMessagesQueue.cs index 901ac75..0cf19c8 100644 --- a/Source/MQTTnet/Server/MqttClientSessionApplicationMessagesQueue.cs +++ b/Source/MQTTnet/Server/MqttClientSessionApplicationMessagesQueue.cs @@ -6,7 +6,7 @@ using System.Threading.Tasks; namespace MQTTnet.Server { - public class MqttClientSessionApplicationMessagesQueue : IDisposable + public class MqttClientSessionApplicationMessagesQueue : Disposable { private readonly AsyncQueue _messageQueue = new AsyncQueue(); @@ -71,9 +71,14 @@ namespace MQTTnet.Server } } - public void Dispose() + protected override void Dispose(bool disposing) { - _messageQueue.Dispose(); + if (disposing) + { + _messageQueue.Dispose(); + } + + base.Dispose(disposing); } } } diff --git a/Source/MQTTnet/Server/MqttClientSessionsManager.cs b/Source/MQTTnet/Server/MqttClientSessionsManager.cs index db70e95..28c163d 100644 --- a/Source/MQTTnet/Server/MqttClientSessionsManager.cs +++ b/Source/MQTTnet/Server/MqttClientSessionsManager.cs @@ -13,7 +13,7 @@ using MQTTnet.Server.Status; namespace MQTTnet.Server { - public class MqttClientSessionsManager : IDisposable + public class MqttClientSessionsManager : Disposable { private readonly AsyncQueue _messageQueue = new AsyncQueue(); @@ -25,13 +25,13 @@ namespace MQTTnet.Server private readonly CancellationToken _cancellationToken; private readonly MqttServerEventDispatcher _eventDispatcher; - private readonly MqttRetainedMessagesManager _retainedMessagesManager; + private readonly IMqttRetainedMessagesManager _retainedMessagesManager; private readonly IMqttServerOptions _options; private readonly IMqttNetChildLogger _logger; public MqttClientSessionsManager( IMqttServerOptions options, - MqttRetainedMessagesManager retainedMessagesManager, + IMqttRetainedMessagesManager retainedMessagesManager, CancellationToken cancellationToken, MqttServerEventDispatcher eventDispatcher, IMqttNetChildLogger logger) @@ -72,7 +72,7 @@ namespace MQTTnet.Server { var clientStatus = new MqttClientStatus(connection); connection.FillStatus(clientStatus); - + var sessionStatus = new MqttSessionStatus(connection.Session, this); connection.Session.FillStatus(sessionStatus); clientStatus.Session = sessionStatus; @@ -91,7 +91,7 @@ namespace MQTTnet.Server { var sessionStatus = new MqttSessionStatus(session, this); session.FillStatus(sessionStatus); - + result.Add(sessionStatus); } @@ -145,9 +145,13 @@ namespace MQTTnet.Server _logger.Verbose("Session for client '{0}' deleted.", clientId); } - public void Dispose() + protected override void Dispose(bool disposing) { - _messageQueue?.Dispose(); + if (disposing) + { + _messageQueue?.Dispose(); + } + base.Dispose(disposing); } private async Task TryProcessQueuedApplicationMessagesAsync(CancellationToken cancellationToken) @@ -230,6 +234,7 @@ namespace MQTTnet.Server { var disconnectType = MqttClientDisconnectType.NotClean; string clientId = null; + var clientWasConnected = true; try { @@ -240,12 +245,13 @@ namespace MQTTnet.Server return; } - clientId = connectPacket.ClientId; - var connectionValidatorContext = await ValidateConnectionAsync(connectPacket, channelAdapter).ConfigureAwait(false); + clientId = connectPacket.ClientId; + if (connectionValidatorContext.ReasonCode != MqttConnectReasonCode.Success) { + clientWasConnected = false; // Send failure response here without preparing a session. The result for a successful connect // will be sent from the session itself. var connAckPacket = channelAdapter.PacketFormatterAdapter.DataConverter.CreateConnAckPacket(connectionValidatorContext); @@ -257,8 +263,8 @@ namespace MQTTnet.Server var connection = await CreateConnectionAsync(connectPacket, connectionValidatorContext, channelAdapter).ConfigureAwait(false); await _eventDispatcher.HandleClientConnectedAsync(clientId).ConfigureAwait(false); - - disconnectType = await connection.RunAsync().ConfigureAwait(false); + + disconnectType = await connection.RunAsync(connectionValidatorContext).ConfigureAwait(false); } catch (OperationCanceledException) { @@ -269,21 +275,24 @@ namespace MQTTnet.Server } finally { - if (clientId != null) + if (clientWasConnected) { - _connections.TryRemove(clientId, out _); - - if (!_options.EnablePersistentSessions) + if (clientId != null) { - await DeleteSessionAsync(clientId).ConfigureAwait(false); + _connections.TryRemove(clientId, out _); + + if (!_options.EnablePersistentSessions) + { + await DeleteSessionAsync(clientId).ConfigureAwait(false); + } } - } - await TryCleanupChannelAsync(channelAdapter).ConfigureAwait(false); + await TryCleanupChannelAsync(channelAdapter).ConfigureAwait(false); - if (clientId != null) - { - await _eventDispatcher.TryHandleClientDisconnectedAsync(clientId, disconnectType).ConfigureAwait(false); + if (clientId != null) + { + await _eventDispatcher.TryHandleClientDisconnectedAsync(clientId, disconnectType).ConfigureAwait(false); + } } } } @@ -328,13 +337,13 @@ namespace MQTTnet.Server { await existingConnection.StopAsync().ConfigureAwait(false); } - + if (isSessionPresent) { if (connectPacket.CleanSession) { session = null; - + _logger.Verbose("Deleting existing session of client '{0}'.", connectPacket.ClientId); } else diff --git a/Source/MQTTnet/Server/MqttClientSubscriptionsManager.cs b/Source/MQTTnet/Server/MqttClientSubscriptionsManager.cs index c84a018..deeadf4 100644 --- a/Source/MQTTnet/Server/MqttClientSubscriptionsManager.cs +++ b/Source/MQTTnet/Server/MqttClientSubscriptionsManager.cs @@ -107,9 +107,16 @@ namespace MQTTnet.Server PacketIdentifier = unsubscribePacket.PacketIdentifier }; - lock (_subscriptions) + foreach (var topicFilter in unsubscribePacket.TopicFilters) { - foreach (var topicFilter in unsubscribePacket.TopicFilters) + var interceptorContext = await InterceptUnsubscribeAsync(topicFilter).ConfigureAwait(false); + if (!interceptorContext.AcceptUnsubscription) + { + unsubAckPacket.ReasonCodes.Add(MqttUnsubscribeReasonCode.ImplementationSpecificError); + continue; + } + + lock (_subscriptions) { if (_subscriptions.Remove(topicFilter)) { @@ -130,19 +137,23 @@ namespace MQTTnet.Server return unsubAckPacket; } - public Task UnsubscribeAsync(IEnumerable topicFilters) + public async Task UnsubscribeAsync(IEnumerable topicFilters) { if (topicFilters == null) throw new ArgumentNullException(nameof(topicFilters)); - lock (_subscriptions) + foreach (var topicFilter in topicFilters) { - foreach (var topicFilter in topicFilters) + var interceptorContext = await InterceptUnsubscribeAsync(topicFilter).ConfigureAwait(false); + if (!interceptorContext.AcceptUnsubscription) { - _subscriptions.Remove(topicFilter); + continue; } - } - return Task.FromResult(0); + lock (_subscriptions) + { + _subscriptions.Remove(topicFilter); + } + } } public CheckSubscriptionsResult CheckSubscriptions(string topic, MqttQualityOfServiceLevel qosLevel) @@ -206,6 +217,17 @@ namespace MQTTnet.Server return context; } + private async Task InterceptUnsubscribeAsync(string topicFilter) + { + var context = new MqttUnsubscriptionInterceptorContext(_clientSession.ClientId, topicFilter, _clientSession.Items); + if (_serverOptions.UnsubscriptionInterceptor != null) + { + await _serverOptions.UnsubscriptionInterceptor.InterceptUnsubscriptionAsync(context).ConfigureAwait(false); + } + + return context; + } + private static CheckSubscriptionsResult CreateSubscriptionResult(MqttQualityOfServiceLevel qosLevel, HashSet subscribedQoSLevels) { MqttQualityOfServiceLevel effectiveQoS; diff --git a/Source/MQTTnet/Server/MqttRetainedMessagesManager.cs b/Source/MQTTnet/Server/MqttRetainedMessagesManager.cs index 2e6af16..f4ebe48 100644 --- a/Source/MQTTnet/Server/MqttRetainedMessagesManager.cs +++ b/Source/MQTTnet/Server/MqttRetainedMessagesManager.cs @@ -3,24 +3,26 @@ using System.Collections.Generic; using System.Linq; using System.Threading.Tasks; using MQTTnet.Diagnostics; +using MQTTnet.Implementations; using MQTTnet.Internal; namespace MQTTnet.Server { - public class MqttRetainedMessagesManager + public class MqttRetainedMessagesManager : IMqttRetainedMessagesManager { private readonly byte[] _emptyArray = new byte[0]; private readonly AsyncLock _messagesLock = new AsyncLock(); private readonly Dictionary _messages = new Dictionary(); - private readonly IMqttNetChildLogger _logger; - private readonly IMqttServerOptions _options; + private IMqttNetChildLogger _logger; + private IMqttServerOptions _options; - public MqttRetainedMessagesManager(IMqttServerOptions options, IMqttNetChildLogger logger) + public Task Start(IMqttServerOptions options, IMqttNetChildLogger logger) { if (logger == null) throw new ArgumentNullException(nameof(logger)); _logger = logger.CreateChildLogger(nameof(MqttRetainedMessagesManager)); _options = options ?? throw new ArgumentNullException(nameof(options)); + return PlatformAbstractionLayer.CompletedTask; } public async Task LoadMessagesAsync() @@ -103,7 +105,7 @@ namespace MQTTnet.Server } } - public async Task> GetSubscribedMessagesAsync(ICollection topicFilters) + public async Task> GetSubscribedMessagesAsync(ICollection topicFilters) { if (topicFilters == null) throw new ArgumentNullException(nameof(topicFilters)); @@ -128,7 +130,7 @@ namespace MQTTnet.Server break; } } - + return matchingRetainedMessages; } diff --git a/Source/MQTTnet/Server/MqttServer.cs b/Source/MQTTnet/Server/MqttServer.cs index f902fc0..4c5ab62 100644 --- a/Source/MQTTnet/Server/MqttServer.cs +++ b/Source/MQTTnet/Server/MqttServer.cs @@ -7,6 +7,7 @@ using MQTTnet.Adapter; using MQTTnet.Client.Publishing; using MQTTnet.Client.Receiving; using MQTTnet.Diagnostics; +using MQTTnet.Exceptions; using MQTTnet.Protocol; using MQTTnet.Server.Status; @@ -19,7 +20,7 @@ namespace MQTTnet.Server private readonly IMqttNetChildLogger _logger; private MqttClientSessionsManager _clientSessionsManager; - private MqttRetainedMessagesManager _retainedMessagesManager; + private IMqttRetainedMessagesManager _retainedMessagesManager; private CancellationTokenSource _cancellationTokenSource; public MqttServer(IEnumerable adapters, IMqttNetChildLogger logger) @@ -48,7 +49,7 @@ namespace MQTTnet.Server get => _eventDispatcher.ClientDisconnectedHandler; set => _eventDispatcher.ClientDisconnectedHandler = value; } - + public IMqttServerClientSubscribedTopicHandler ClientSubscribedTopicHandler { get => _eventDispatcher.ClientSubscribedTopicHandler; @@ -60,7 +61,7 @@ namespace MQTTnet.Server get => _eventDispatcher.ClientUnsubscribedTopicHandler; set => _eventDispatcher.ClientUnsubscribedTopicHandler = value; } - + public IMqttApplicationMessageReceivedHandler ApplicationMessageReceivedHandler { get => _eventDispatcher.ApplicationMessageReceivedHandler; @@ -117,11 +118,14 @@ namespace MQTTnet.Server { Options = options ?? throw new ArgumentNullException(nameof(options)); + if (Options.RetainedMessagesManager == null) throw new MqttConfigurationException("options.RetainedMessagesManager should not be null."); + if (_cancellationTokenSource != null) throw new InvalidOperationException("The server is already started."); _cancellationTokenSource = new CancellationTokenSource(); - _retainedMessagesManager = new MqttRetainedMessagesManager(Options, _logger); + _retainedMessagesManager = Options.RetainedMessagesManager; + await _retainedMessagesManager.Start(Options, _logger); await _retainedMessagesManager.LoadMessagesAsync().ConfigureAwait(false); _clientSessionsManager = new MqttClientSessionsManager(Options, _retainedMessagesManager, _cancellationTokenSource.Token, _eventDispatcher, _logger); @@ -150,9 +154,9 @@ namespace MQTTnet.Server { return; } - + await _clientSessionsManager.StopAsync().ConfigureAwait(false); - + _cancellationTokenSource.Cancel(false); foreach (var adapter in _adapters) diff --git a/Source/MQTTnet/Server/MqttServerOptions.cs b/Source/MQTTnet/Server/MqttServerOptions.cs index 7147ef8..9773e72 100644 --- a/Source/MQTTnet/Server/MqttServerOptions.cs +++ b/Source/MQTTnet/Server/MqttServerOptions.cs @@ -21,11 +21,15 @@ namespace MQTTnet.Server public IMqttServerConnectionValidator ConnectionValidator { get; set; } public IMqttServerApplicationMessageInterceptor ApplicationMessageInterceptor { get; set; } - + public IMqttServerClientMessageQueueInterceptor ClientMessageQueueInterceptor { get; set; } public IMqttServerSubscriptionInterceptor SubscriptionInterceptor { get; set; } + public IMqttServerUnsubscriptionInterceptor UnsubscriptionInterceptor { get; set; } + public IMqttServerStorage Storage { get; set; } + + public IMqttRetainedMessagesManager RetainedMessagesManager { get; set; } = new MqttRetainedMessagesManager(); } } diff --git a/Source/MQTTnet/Server/MqttServerOptionsBuilder.cs b/Source/MQTTnet/Server/MqttServerOptionsBuilder.cs index c25af84..2970fab 100644 --- a/Source/MQTTnet/Server/MqttServerOptionsBuilder.cs +++ b/Source/MQTTnet/Server/MqttServerOptionsBuilder.cs @@ -57,7 +57,7 @@ namespace MQTTnet.Server _options.DefaultEndpointOptions.IsEnabled = false; return this; } - + public MqttServerOptionsBuilder WithEncryptedEndpoint() { _options.TlsEndpointOptions.IsEnabled = true; @@ -118,13 +118,19 @@ namespace MQTTnet.Server return this; } #endif - + public MqttServerOptionsBuilder WithStorage(IMqttServerStorage value) { _options.Storage = value; return this; } + public MqttServerOptionsBuilder WithRetainedMessagesManager(IMqttRetainedMessagesManager value) + { + _options.RetainedMessagesManager = value; + return this; + } + public MqttServerOptionsBuilder WithConnectionValidator(IMqttServerConnectionValidator value) { _options.ConnectionValidator = value; @@ -155,6 +161,12 @@ namespace MQTTnet.Server return this; } + public MqttServerOptionsBuilder WithUnsubscriptionInterceptor(IMqttServerUnsubscriptionInterceptor value) + { + _options.UnsubscriptionInterceptor = value; + return this; + } + public MqttServerOptionsBuilder WithSubscriptionInterceptor(Action value) { _options.SubscriptionInterceptor = new MqttServerSubscriptionInterceptorDelegate(value); diff --git a/Source/MQTTnet/Server/MqttUnsubscriptionInterceptorContext.cs b/Source/MQTTnet/Server/MqttUnsubscriptionInterceptorContext.cs new file mode 100644 index 0000000..b33cbac --- /dev/null +++ b/Source/MQTTnet/Server/MqttUnsubscriptionInterceptorContext.cs @@ -0,0 +1,27 @@ +using System.Collections.Generic; + +namespace MQTTnet.Server +{ + public class MqttUnsubscriptionInterceptorContext + { + public MqttUnsubscriptionInterceptorContext(string clientId, string topic, IDictionary sessionItems) + { + ClientId = clientId; + Topic = topic; + SessionItems = sessionItems; + } + + public string ClientId { get; } + + public string Topic { get; set; } + + /// + /// Gets or sets a key/value collection that can be used to share data within the scope of this session. + /// + public IDictionary SessionItems { get; } + + public bool AcceptUnsubscription { get; set; } = true; + + public bool CloseConnection { get; set; } + } +} diff --git a/Tests/MQTTnet.AspNetCore.Tests/MQTTnet.AspNetCore.Tests.csproj b/Tests/MQTTnet.AspNetCore.Tests/MQTTnet.AspNetCore.Tests.csproj index c81f91c..06b5976 100644 --- a/Tests/MQTTnet.AspNetCore.Tests/MQTTnet.AspNetCore.Tests.csproj +++ b/Tests/MQTTnet.AspNetCore.Tests/MQTTnet.AspNetCore.Tests.csproj @@ -6,9 +6,9 @@ - - - + + + diff --git a/Tests/MQTTnet.Benchmarks/Configurations/RuntimeCompareConfig.cs b/Tests/MQTTnet.Benchmarks/Configurations/RuntimeCompareConfig.cs index 608271e..5838424 100644 --- a/Tests/MQTTnet.Benchmarks/Configurations/RuntimeCompareConfig.cs +++ b/Tests/MQTTnet.Benchmarks/Configurations/RuntimeCompareConfig.cs @@ -9,8 +9,8 @@ namespace MQTTnet.Benchmarks.Configurations { public RuntimeCompareConfig() { - Add(Job.Default.With(Runtime.Clr)); - Add(Job.Default.With(Runtime.Core).With(CsProjCoreToolchain.NetCoreApp21)); + Add(Job.Default.With(ClrRuntime.Net472)); + Add(Job.Default.With(CoreRuntime.Core22).With(CsProjCoreToolchain.NetCoreApp22)); } } diff --git a/Tests/MQTTnet.Benchmarks/LoggerBenchmark.cs b/Tests/MQTTnet.Benchmarks/LoggerBenchmark.cs index 0039434..cfc88d4 100644 --- a/Tests/MQTTnet.Benchmarks/LoggerBenchmark.cs +++ b/Tests/MQTTnet.Benchmarks/LoggerBenchmark.cs @@ -1,9 +1,10 @@ using BenchmarkDotNet.Attributes; +using BenchmarkDotNet.Jobs; using MQTTnet.Diagnostics; namespace MQTTnet.Benchmarks { - [ClrJob] + [SimpleJob(RuntimeMoniker.Net461)] [RPlotExporter] [MemoryDiagnoser] public class LoggerBenchmark diff --git a/Tests/MQTTnet.Benchmarks/MQTTnet.Benchmarks.csproj b/Tests/MQTTnet.Benchmarks/MQTTnet.Benchmarks.csproj index 03aabc8..b34c5a3 100644 --- a/Tests/MQTTnet.Benchmarks/MQTTnet.Benchmarks.csproj +++ b/Tests/MQTTnet.Benchmarks/MQTTnet.Benchmarks.csproj @@ -9,7 +9,7 @@ - + diff --git a/Tests/MQTTnet.Benchmarks/MessageProcessingBenchmark.cs b/Tests/MQTTnet.Benchmarks/MessageProcessingBenchmark.cs index c821ee0..99fe030 100644 --- a/Tests/MQTTnet.Benchmarks/MessageProcessingBenchmark.cs +++ b/Tests/MQTTnet.Benchmarks/MessageProcessingBenchmark.cs @@ -1,11 +1,12 @@ using BenchmarkDotNet.Attributes; +using BenchmarkDotNet.Jobs; using MQTTnet.Client; using MQTTnet.Client.Options; using MQTTnet.Server; namespace MQTTnet.Benchmarks { - [ClrJob] + [SimpleJob(RuntimeMoniker.Net461)] [RPlotExporter, RankColumn] [MemoryDiagnoser] public class MessageProcessingBenchmark diff --git a/Tests/MQTTnet.Benchmarks/SerializerBenchmark.cs b/Tests/MQTTnet.Benchmarks/SerializerBenchmark.cs index 51e7ecb..00433cf 100644 --- a/Tests/MQTTnet.Benchmarks/SerializerBenchmark.cs +++ b/Tests/MQTTnet.Benchmarks/SerializerBenchmark.cs @@ -8,10 +8,11 @@ using MQTTnet.Adapter; using MQTTnet.Channel; using MQTTnet.Formatter; using MQTTnet.Formatter.V3; +using BenchmarkDotNet.Jobs; namespace MQTTnet.Benchmarks { - [ClrJob] + [SimpleJob(RuntimeMoniker.Net461)] [RPlotExporter] [MemoryDiagnoser] public class SerializerBenchmark diff --git a/Tests/MQTTnet.Benchmarks/TopicFilterComparerBenchmark.cs b/Tests/MQTTnet.Benchmarks/TopicFilterComparerBenchmark.cs index bbce630..2df92ad 100644 --- a/Tests/MQTTnet.Benchmarks/TopicFilterComparerBenchmark.cs +++ b/Tests/MQTTnet.Benchmarks/TopicFilterComparerBenchmark.cs @@ -1,10 +1,11 @@ using BenchmarkDotNet.Attributes; +using BenchmarkDotNet.Jobs; using MQTTnet.Server; using System; namespace MQTTnet.Benchmarks { - [ClrJob] + [SimpleJob(RuntimeMoniker.Net461)] [RPlotExporter] [MemoryDiagnoser] public class TopicFilterComparerBenchmark diff --git a/Tests/MQTTnet.Core.Tests/AsyncAutoResentEvent_Tests.cs b/Tests/MQTTnet.Core.Tests/AsyncAutoResentEvent_Tests.cs deleted file mode 100644 index d72712d..0000000 --- a/Tests/MQTTnet.Core.Tests/AsyncAutoResentEvent_Tests.cs +++ /dev/null @@ -1,237 +0,0 @@ -using System; -using System.Threading; -using System.Threading.Tasks; -using Microsoft.VisualStudio.TestTools.UnitTesting; -using MQTTnet.Internal; - -namespace MQTTnet.Tests -{ - [TestClass] - // Inspired from the vs-threading tests (https://github.com/Microsoft/vs-threading/blob/master/src/Microsoft.VisualStudio.Threading.Tests/AsyncAutoResetEventTests.cs) - public class AsyncAutoResetEvent_Tests - { - private readonly AsyncAutoResetEvent _aare; - - public AsyncAutoResetEvent_Tests() - { - _aare = new AsyncAutoResetEvent(); - } - - [TestMethod] - public async Task Cleanup_Waiters() - { - var @lock = new AsyncAutoResetEvent(); - - var waitOnePassed = false; - -#pragma warning disable 4014 - Task.Run(async () => -#pragma warning restore 4014 - { - await @lock.WaitOneAsync(TimeSpan.FromSeconds(2)); - waitOnePassed = true; - }); - - await Task.Delay(500); - - Assert.AreEqual(1, @lock.WaitersCount); - - @lock.Set(); - - await Task.Delay(1000); - - Assert.IsTrue(waitOnePassed); - Assert.AreEqual(0, @lock.WaitersCount); - } - - [TestMethod] - public async Task SingleThreadedPulse() - { - for (int i = 0; i < 5; i++) - { - var t = _aare.WaitOneAsync(); - Assert.IsFalse(t.IsCompleted); - _aare.Set(); - await t; - Assert.IsTrue(t.IsCompleted); - } - } - - [TestMethod] - public async Task MultipleSetOnlySignalsOnce() - { - _aare.Set(); - _aare.Set(); - await _aare.WaitOneAsync(); - var t = _aare.WaitOneAsync(); - Assert.IsFalse(t.IsCompleted); - await Task.Delay(500); - Assert.IsFalse(t.IsCompleted); - _aare.Set(); - await t; - Assert.IsTrue(t.IsCompleted); - } - - [TestMethod] - public async Task OrderPreservingQueue() - { - var waiters = new Task[5]; - for (int i = 0; i < waiters.Length; i++) - { - waiters[i] = _aare.WaitOneAsync(); - } - - for (int i = 0; i < waiters.Length; i++) - { - _aare.Set(); - await waiters[i].ConfigureAwait(false); - } - } - - // This test does not work in appveyor but on local machine it does!? - /////// - /////// Verifies that inlining continuations do not have to complete execution before Set() returns. - /////// - ////[TestMethod] - ////public async Task SetReturnsBeforeInlinedContinuations() - ////{ - //// var setReturned = new ManualResetEventSlim(); - //// var inlinedContinuation = _aare.WaitOneAsync() - //// .ContinueWith(delegate - //// { - //// // Arrange to synchronously block the continuation until Set() has returned, - //// // which would deadlock if Set does not return until inlined continuations complete. - //// Assert.IsTrue(setReturned.Wait(500)); - //// }); - //// await Task.Delay(100); - //// _aare.Set(); - //// setReturned.Set(); - //// Assert.IsTrue(inlinedContinuation.Wait(500)); - ////} - - [TestMethod] - public void WaitAsync_WithCancellationToken() - { - var cts = new CancellationTokenSource(); - Task waitTask = _aare.WaitOneAsync(cts.Token); - Assert.IsFalse(waitTask.IsCompleted); - - // Cancel the request and ensure that it propagates to the task. - cts.Cancel(); - try - { - waitTask.GetAwaiter().GetResult(); - Assert.IsTrue(false, "Task was expected to transition to a canceled state."); - } - catch (OperationCanceledException) - { - } - - // Now set the event and verify that a future waiter gets the signal immediately. - _aare.Set(); - waitTask = _aare.WaitOneAsync(); - Assert.AreEqual(TaskStatus.WaitingForActivation, waitTask.Status); - } - - [TestMethod] - public void WaitAsync_WithCancellationToken_Precanceled() - { - // We construct our own pre-canceled token so that we can do - // a meaningful identity check later. - var tokenSource = new CancellationTokenSource(); - tokenSource.Cancel(); - var token = tokenSource.Token; - - // Verify that a pre-set signal is not reset by a canceled wait request. - _aare.Set(); - try - { - _aare.WaitOneAsync(token).GetAwaiter().GetResult(); - Assert.IsTrue(false, "Task was expected to transition to a canceled state."); - } - catch (OperationCanceledException ex) - { - Assert.AreEqual(token, ex.CancellationToken); - } - - // Verify that the signal was not acquired. - Task waitTask = _aare.WaitOneAsync(); - Assert.AreEqual(TaskStatus.RanToCompletion, waitTask.Status); - } - - [TestMethod] - public async Task WaitAsync_WithTimeout() - { - Task waitTask = _aare.WaitOneAsync(TimeSpan.FromMilliseconds(500)); - Assert.IsFalse(waitTask.IsCompleted); - - // Cancel the request and ensure that it propagates to the task. - await Task.Delay(1000).ConfigureAwait(false); - try - { - waitTask.GetAwaiter().GetResult(); - Assert.IsTrue(false, "Task was expected to transition to a timeout state."); - } - catch (TimeoutException) - { - Assert.IsTrue(true); - } - - // Now set the event and verify that a future waiter gets the signal immediately. - _aare.Set(); - waitTask = _aare.WaitOneAsync(TimeSpan.FromMilliseconds(500)); - Assert.AreEqual(TaskStatus.RanToCompletion, waitTask.Status); - } - - [TestMethod] - public void WaitAsync_Canceled_DoesNotInlineContinuations() - { - var cts = new CancellationTokenSource(); - var task = _aare.WaitOneAsync(cts.Token); - - var completingActionFinished = new ManualResetEventSlim(); - var continuation = task.ContinueWith( - _ => Assert.IsTrue(completingActionFinished.Wait(500)), - CancellationToken.None, - TaskContinuationOptions.None, - TaskScheduler.Default); - - cts.Cancel(); - completingActionFinished.Set(); - - // Rethrow the exception if it turned out it deadlocked. - continuation.GetAwaiter().GetResult(); - } - - [TestMethod] - public async Task AsyncAutoResetEvent() - { - var aare = new AsyncAutoResetEvent(); - - var globalI = 0; -#pragma warning disable 4014 - Task.Run(async () => -#pragma warning restore 4014 - { - await aare.WaitOneAsync(CancellationToken.None); - globalI += 1; - }); - -#pragma warning disable 4014 - Task.Run(async () => -#pragma warning restore 4014 - { - await aare.WaitOneAsync(CancellationToken.None); - globalI += 2; - }); - - await Task.Delay(500); - aare.Set(); - await Task.Delay(500); - aare.Set(); - await Task.Delay(100); - - Assert.AreEqual(3, globalI); - } - } -} \ No newline at end of file diff --git a/Tests/MQTTnet.Core.Tests/MQTTnet.Tests.csproj b/Tests/MQTTnet.Core.Tests/MQTTnet.Tests.csproj index 07812bc..c0749f7 100644 --- a/Tests/MQTTnet.Core.Tests/MQTTnet.Tests.csproj +++ b/Tests/MQTTnet.Core.Tests/MQTTnet.Tests.csproj @@ -1,14 +1,14 @@  - netcoreapp3.1 + netcoreapp3.1;net461 false - - - + + + diff --git a/Tests/MQTTnet.Core.Tests/MQTTv5/Client_Tests.cs b/Tests/MQTTnet.Core.Tests/MQTTv5/Client_Tests.cs index 317c709..f033b5b 100644 --- a/Tests/MQTTnet.Core.Tests/MQTTv5/Client_Tests.cs +++ b/Tests/MQTTnet.Core.Tests/MQTTv5/Client_Tests.cs @@ -1,4 +1,5 @@ using System.Collections.Generic; +using System.Threading; using System.Threading.Tasks; using Microsoft.VisualStudio.TestTools.UnitTesting; using MQTTnet.Client; @@ -17,10 +18,12 @@ namespace MQTTnet.Tests.MQTTv5 [TestClass] public class Client_Tests { + public TestContext TestContext { get; set; } + [TestMethod] public async Task Connect_With_New_Mqtt_Features() { - using (var testEnvironment = new TestEnvironment()) + using (var testEnvironment = new TestEnvironment(TestContext)) { await testEnvironment.StartServerAsync(); @@ -57,6 +60,65 @@ namespace MQTTnet.Tests.MQTTv5 Assert.AreEqual(2, receivedMessage.UserProperties.Count); } } + [TestMethod] + public async Task Connect_With_AssignedClientId() + { + using (var testEnvironment = new TestEnvironment(TestContext)) + { + string serverConnectedClientId = null; + string serverDisconnectedClientId = null; + string clientAssignedClientId = null; + + // Arrange server + var disconnectedMre = new ManualResetEventSlim(); + var serverOptions = new MqttServerOptionsBuilder() + .WithConnectionValidator((context) => + { + if (string.IsNullOrEmpty(context.ClientId)) + { + context.AssignedClientIdentifier = "test123"; + context.ReasonCode = MqttConnectReasonCode.Success; + } + }); + await testEnvironment.StartServerAsync(serverOptions); + testEnvironment.Server.UseClientConnectedHandler((args) => + { + serverConnectedClientId = args.ClientId; + }); + testEnvironment.Server.UseClientDisconnectedHandler((args) => + { + serverDisconnectedClientId = args.ClientId; + disconnectedMre.Set(); + }); + + // Arrange client + var client = testEnvironment.CreateClient(); + client.UseConnectedHandler((args) => + { + clientAssignedClientId = args.AuthenticateResult.AssignedClientIdentifier; + }); + + // Act + await client.ConnectAsync(new MqttClientOptionsBuilder() + .WithTcpServer("127.0.0.1", testEnvironment.ServerPort) + .WithProtocolVersion(MqttProtocolVersion.V500) + .WithClientId(null) + .Build()); + await client.DisconnectAsync(); + + // Wait for ClientDisconnectedHandler to trigger + disconnectedMre.Wait(500); + + // Assert + Assert.IsNotNull(serverConnectedClientId); + Assert.IsNotNull(serverDisconnectedClientId); + Assert.IsNotNull(clientAssignedClientId); + Assert.AreEqual("test123", serverConnectedClientId); + Assert.AreEqual("test123", serverDisconnectedClientId); + Assert.AreEqual("test123", clientAssignedClientId); + + } + } [TestMethod] public async Task Connect() @@ -297,7 +359,7 @@ namespace MQTTnet.Tests.MQTTv5 [TestMethod] public async Task Publish_And_Receive_New_Properties() { - using (var testEnvironment = new TestEnvironment()) + using (var testEnvironment = new TestEnvironment(TestContext)) { await testEnvironment.StartServerAsync(); diff --git a/Tests/MQTTnet.Core.Tests/MQTTv5/Feature_Tests.cs b/Tests/MQTTnet.Core.Tests/MQTTv5/Feature_Tests.cs index 294ee84..6507fff 100644 --- a/Tests/MQTTnet.Core.Tests/MQTTv5/Feature_Tests.cs +++ b/Tests/MQTTnet.Core.Tests/MQTTv5/Feature_Tests.cs @@ -13,10 +13,12 @@ namespace MQTTnet.Tests.MQTTv5 [TestClass] public class Feature_Tests { + public TestContext TestContext { get; set; } + [TestMethod] public async Task Use_User_Properties() { - using (var testEnvironment = new TestEnvironment()) + using (var testEnvironment = new TestEnvironment(TestContext)) { await testEnvironment.StartServerAsync(); diff --git a/Tests/MQTTnet.Core.Tests/MQTTv5/Server_Tests.cs b/Tests/MQTTnet.Core.Tests/MQTTv5/Server_Tests.cs index e3020ce..014bd92 100644 --- a/Tests/MQTTnet.Core.Tests/MQTTv5/Server_Tests.cs +++ b/Tests/MQTTnet.Core.Tests/MQTTv5/Server_Tests.cs @@ -11,10 +11,12 @@ namespace MQTTnet.Tests.MQTTv5 [TestClass] public class Server_Tests { + public TestContext TestContext { get; set; } + [TestMethod] public async Task Will_Message_Send() { - using (var testEnvironment = new TestEnvironment()) + using (var testEnvironment = new TestEnvironment(TestContext)) { var receivedMessagesCount = 0; diff --git a/Tests/MQTTnet.Core.Tests/ManagedMqttClient_Tests.cs b/Tests/MQTTnet.Core.Tests/ManagedMqttClient_Tests.cs index 0aeea6d..edfe058 100644 --- a/Tests/MQTTnet.Core.Tests/ManagedMqttClient_Tests.cs +++ b/Tests/MQTTnet.Core.Tests/ManagedMqttClient_Tests.cs @@ -1,10 +1,13 @@ -using System.Collections.Generic; +using System; +using System.Collections.Generic; +using System.Linq; using System.Threading; using System.Threading.Tasks; using Microsoft.VisualStudio.TestTools.UnitTesting; using MQTTnet.Client; using MQTTnet.Client.Connecting; using MQTTnet.Client.Options; +using MQTTnet.Client.Receiving; using MQTTnet.Diagnostics; using MQTTnet.Extensions.ManagedClient; using MQTTnet.Server; @@ -15,6 +18,8 @@ namespace MQTTnet.Tests [TestClass] public class ManagedMqttClient_Tests { + public TestContext TestContext { get; set; } + [TestMethod] public async Task Drop_New_Messages_On_Full_Queue() { @@ -51,7 +56,7 @@ namespace MQTTnet.Tests [TestMethod] public async Task ManagedClients_Will_Message_Send() { - using (var testEnvironment = new TestEnvironment()) + using (var testEnvironment = new TestEnvironment(TestContext)) { var receivedMessagesCount = 0; @@ -85,7 +90,7 @@ namespace MQTTnet.Tests [TestMethod] public async Task Start_Stop() { - using (var testEnvironment = new TestEnvironment()) + using (var testEnvironment = new TestEnvironment(TestContext)) { var factory = new MqttFactory(); @@ -95,25 +100,24 @@ namespace MQTTnet.Tests var clientOptions = new MqttClientOptionsBuilder() .WithTcpServer("localhost", testEnvironment.ServerPort); - TaskCompletionSource connected = new TaskCompletionSource(); - managedClient.ConnectedHandler = new MqttClientConnectedHandlerDelegate(e => { connected.SetResult(true);}); + var connected = GetConnectedTask(managedClient); await managedClient.StartAsync(new ManagedMqttClientOptionsBuilder() .WithClientOptions(clientOptions) .Build()); - await connected.Task; + await connected; await managedClient.StopAsync(); Assert.AreEqual(0, (await server.GetClientStatusAsync()).Count); } } - + [TestMethod] public async Task Storage_Queue_Drains() { - using (var testEnvironment = new TestEnvironment()) + using (var testEnvironment = new TestEnvironment(TestContext)) { testEnvironment.IgnoreClientLogErrors = true; testEnvironment.IgnoreServerLogErrors = true; @@ -127,12 +131,7 @@ namespace MQTTnet.Tests .WithTcpServer("localhost", testEnvironment.ServerPort); var storage = new ManagedMqttClientTestStorage(); - TaskCompletionSource connected = new TaskCompletionSource(); - managedClient.ConnectedHandler = new MqttClientConnectedHandlerDelegate(e => - { - managedClient.ConnectedHandler = null; - connected.SetResult(true); - }); + var connected = GetConnectedTask(managedClient); await managedClient.StartAsync(new ManagedMqttClientOptionsBuilder() .WithClientOptions(clientOptions) @@ -140,7 +139,7 @@ namespace MQTTnet.Tests .WithAutoReconnectDelay(System.TimeSpan.FromSeconds(5)) .Build()); - await connected.Task; + await connected; await testEnvironment.Server.StopAsync(); @@ -151,17 +150,12 @@ namespace MQTTnet.Tests //in storage at this point (i.e. no waiting). Assert.AreEqual(1, storage.GetMessageCount()); - connected = new TaskCompletionSource(); - managedClient.ConnectedHandler = new MqttClientConnectedHandlerDelegate(e => - { - managedClient.ConnectedHandler = null; - connected.SetResult(true); - }); + connected = GetConnectedTask(managedClient); await testEnvironment.Server.StartAsync(new MqttServerOptionsBuilder() .WithDefaultEndpointPort(testEnvironment.ServerPort).Build()); - await connected.Task; + await connected; //Wait 500ms here so the client has time to publish the queued message await Task.Delay(500); @@ -171,8 +165,235 @@ namespace MQTTnet.Tests await managedClient.StopAsync(); } } + + [TestMethod] + public async Task Subscriptions_And_Unsubscriptions_Are_Made_And_Reestablished_At_Reconnect() + { + using (var testEnvironment = new TestEnvironment(TestContext)) + { + var unmanagedClient = testEnvironment.CreateClient(); + var managedClient = await CreateManagedClientAsync(testEnvironment, unmanagedClient); + + var received = SetupReceivingOfMessages(managedClient, 2); + + // Perform some opposing subscriptions and unsubscriptions to verify + // that these conflicting subscriptions are handled correctly + await managedClient.SubscribeAsync("keptSubscribed"); + await managedClient.SubscribeAsync("subscribedThenUnsubscribed"); + + await managedClient.UnsubscribeAsync("subscribedThenUnsubscribed"); + await managedClient.UnsubscribeAsync("unsubscribedThenSubscribed"); + + await managedClient.SubscribeAsync("unsubscribedThenSubscribed"); + + //wait a bit for the subscriptions to become established before the messages are published + await Task.Delay(500); + + var sendingClient = await testEnvironment.ConnectClientAsync(); + + async Task PublishMessages() + { + await sendingClient.PublishAsync("keptSubscribed", new byte[] { 1 }); + await sendingClient.PublishAsync("subscribedThenUnsubscribed", new byte[] { 1 }); + await sendingClient.PublishAsync("unsubscribedThenSubscribed", new byte[] { 1 }); + } + + await PublishMessages(); + + async Task AssertMessagesReceived() + { + var messages = await received; + Assert.AreEqual("keptSubscribed", messages[0].Topic); + Assert.AreEqual("unsubscribedThenSubscribed", messages[1].Topic); + } + + await AssertMessagesReceived(); + + var connected = GetConnectedTask(managedClient); + + await unmanagedClient.DisconnectAsync(); + + // the managed client has to reconnect by itself + await connected; + + // wait a bit so that the managed client can reestablish the subscriptions + await Task.Delay(500); + + received = SetupReceivingOfMessages(managedClient, 2); + + await PublishMessages(); + + // and then the same subscriptions need to exist again + await AssertMessagesReceived(); + } + } + + // This case also serves as a regression test for the previous behavior which re-published + // each and every existing subscriptions with every new subscription that was made + // (causing performance problems and having the visible symptom of retained messages being received again) + [TestMethod] + public async Task Subscriptions_Subscribe_Only_New_Subscriptions() + { + using (var testEnvironment = new TestEnvironment(TestContext)) + { + var managedClient = await CreateManagedClientAsync(testEnvironment); + + var sendingClient = await testEnvironment.ConnectClientAsync(); + + await managedClient.SubscribeAsync("topic"); + + //wait a bit for the subscription to become established + await Task.Delay(500); + + await sendingClient.PublishAsync(new MqttApplicationMessage { Topic = "topic", Payload = new byte[] { 1 }, Retain = true }); + + var messages = await SetupReceivingOfMessages(managedClient, 1); + + Assert.AreEqual(1, messages.Count); + Assert.AreEqual("topic", messages.Single().Topic); + + await managedClient.SubscribeAsync("anotherTopic"); + + await Task.Delay(500); + + // The subscription of the other topic must not trigger a re-subscription of the existing topic + // (and thus renewed receiving of the retained message) + Assert.AreEqual(1, messages.Count); + } + } + + // This case also serves as a regression test for the previous behavior + // that subscriptions were only published at the ConnectionCheckInterval + [TestMethod] + public async Task Subscriptions_Are_Published_Immediately() + { + using (var testEnvironment = new TestEnvironment(TestContext)) + { + // Use a long connection check interval to verify that the subscriptions + // do not depend on the connection check interval anymore + var connectionCheckInterval = TimeSpan.FromSeconds(10); + var managedClient = await CreateManagedClientAsync(testEnvironment, null, connectionCheckInterval); + var sendingClient = await testEnvironment.ConnectClientAsync(); + await sendingClient.PublishAsync(new MqttApplicationMessage { Topic = "topic", Payload = new byte[] { 1 }, Retain = true }); + + await managedClient.SubscribeAsync("topic"); + + var subscribeTime = DateTime.UtcNow; + + var messages = await SetupReceivingOfMessages(managedClient, 1); + + var elapsed = DateTime.UtcNow - subscribeTime; + Assert.IsTrue(elapsed < TimeSpan.FromSeconds(1), $"Subscriptions must be activated immediately, this one took {elapsed}"); + Assert.AreEqual(messages.Single().Topic, "topic"); + } + } + + [TestMethod] + public async Task Subscriptions_Are_Cleared_At_Logout() + { + using (var testEnvironment = new TestEnvironment(TestContext)) + { + var managedClient = await CreateManagedClientAsync(testEnvironment); + + var sendingClient = await testEnvironment.ConnectClientAsync(); + await sendingClient.PublishAsync(new MqttApplicationMessage + { Topic = "topic", Payload = new byte[] { 1 }, Retain = true }); + + // Wait a bit for the retained message to be available + await Task.Delay(500); + + await managedClient.SubscribeAsync("topic"); + + await SetupReceivingOfMessages(managedClient, 1); + + await managedClient.StopAsync(); + + var clientOptions = new MqttClientOptionsBuilder() + .WithTcpServer("localhost", testEnvironment.ServerPort); + await managedClient.StartAsync(new ManagedMqttClientOptionsBuilder() + .WithClientOptions(clientOptions) + .WithAutoReconnectDelay(TimeSpan.FromSeconds(1)) + .Build()); + + var messages = new List(); + managedClient.ApplicationMessageReceivedHandler = new MqttApplicationMessageReceivedHandlerDelegate(r => + { + messages.Add(r.ApplicationMessage); + }); + + await Task.Delay(500); + + // After reconnect and then some delay, the retained message must not be received, + // showing that the subscriptions were cleared + Assert.AreEqual(0, messages.Count); + } + } + + private async Task CreateManagedClientAsync( + TestEnvironment testEnvironment, + IMqttClient underlyingClient = null, + TimeSpan? connectionCheckInterval = null) + { + await testEnvironment.StartServerAsync(); + + var clientOptions = new MqttClientOptionsBuilder() + .WithTcpServer("localhost", testEnvironment.ServerPort); + + var managedOptions = new ManagedMqttClientOptionsBuilder() + .WithClientOptions(clientOptions) + .Build(); + + // Use a short connection check interval so that subscription operations are performed quickly + // in order to verify against a previous implementation that performed subscriptions only + // at connection check intervals + managedOptions.ConnectionCheckInterval = connectionCheckInterval ?? TimeSpan.FromSeconds(0.1); + + var managedClient = + new ManagedMqttClient(underlyingClient ?? testEnvironment.CreateClient(), new MqttNetLogger().CreateChildLogger()); + + var connected = GetConnectedTask(managedClient); + + await managedClient.StartAsync(managedOptions); + + await connected; + + return managedClient; + } + + /// + /// Returns a task that will finish when the has connected + /// + private Task GetConnectedTask(ManagedMqttClient managedClient) + { + TaskCompletionSource connected = new TaskCompletionSource(); + managedClient.ConnectedHandler = new MqttClientConnectedHandlerDelegate(e => + { + managedClient.ConnectedHandler = null; + connected.SetResult(true); + }); + return connected.Task; + } + + /// + /// Returns a task that will return the messages received on + /// when have been received + /// + private Task> SetupReceivingOfMessages(ManagedMqttClient managedClient, int expectedNumberOfMessages) + { + var receivedMessages = new List(); + var allReceived = new TaskCompletionSource>(); + managedClient.ApplicationMessageReceivedHandler = new MqttApplicationMessageReceivedHandlerDelegate(r => + { + receivedMessages.Add(r.ApplicationMessage); + if (receivedMessages.Count == expectedNumberOfMessages) + { + allReceived.SetResult(receivedMessages); + } + }); + return allReceived.Task; + } } - + public class ManagedMqttClientTestStorage : IManagedMqttClientStorage { private IList _messages = null; diff --git a/Tests/MQTTnet.Core.Tests/Mockups/TestClientWrapper.cs b/Tests/MQTTnet.Core.Tests/Mockups/TestClientWrapper.cs new file mode 100644 index 0000000..2500a6f --- /dev/null +++ b/Tests/MQTTnet.Core.Tests/Mockups/TestClientWrapper.cs @@ -0,0 +1,94 @@ +using System.Threading; +using System.Threading.Tasks; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using MQTTnet.Client; +using MQTTnet.Client.Connecting; +using MQTTnet.Client.Disconnecting; +using MQTTnet.Client.ExtendedAuthenticationExchange; +using MQTTnet.Client.Options; +using MQTTnet.Client.Publishing; +using MQTTnet.Client.Receiving; +using MQTTnet.Client.Subscribing; +using MQTTnet.Client.Unsubscribing; + +namespace MQTTnet.Tests.Mockups +{ + public class TestClientWrapper : IMqttClient + { + public TestClientWrapper(IMqttClient implementation, TestContext testContext) + { + Implementation = implementation; + TestContext = testContext; + } + + public IMqttClient Implementation { get; } + public TestContext TestContext { get; } + + public bool IsConnected => Implementation.IsConnected; + + public IMqttClientOptions Options => Implementation.Options; + + public IMqttClientConnectedHandler ConnectedHandler { get => Implementation.ConnectedHandler; set => Implementation.ConnectedHandler = value; } + public IMqttClientDisconnectedHandler DisconnectedHandler { get => Implementation.DisconnectedHandler; set => Implementation.DisconnectedHandler = value; } + public IMqttApplicationMessageReceivedHandler ApplicationMessageReceivedHandler { get => Implementation.ApplicationMessageReceivedHandler; set => Implementation.ApplicationMessageReceivedHandler = value; } + + public Task ConnectAsync(IMqttClientOptions options, CancellationToken cancellationToken) + { + switch (options) + { + case MqttClientOptionsBuilder builder: + { + var existingClientId = builder.Build().ClientId; + if (existingClientId != null && !existingClientId.StartsWith(TestContext.TestName)) + { + builder.WithClientId(TestContext.TestName + existingClientId); + } + } + break; + case MqttClientOptions op: + { + var existingClientId = op.ClientId; + if (existingClientId != null && !existingClientId.StartsWith(TestContext.TestName)) + { + op.ClientId = TestContext.TestName + existingClientId; + } + } + break; + default: + break; + } + + return Implementation.ConnectAsync(options, cancellationToken); + } + + public Task DisconnectAsync(MqttClientDisconnectOptions options, CancellationToken cancellationToken) + { + return Implementation.DisconnectAsync(options, cancellationToken); + } + + public void Dispose() + { + Implementation.Dispose(); + } + + public Task PublishAsync(MqttApplicationMessage applicationMessage, CancellationToken cancellationToken) + { + return Implementation.PublishAsync(applicationMessage, cancellationToken); + } + + public Task SendExtendedAuthenticationExchangeDataAsync(MqttExtendedAuthenticationExchangeData data, CancellationToken cancellationToken) + { + return Implementation.SendExtendedAuthenticationExchangeDataAsync(data, cancellationToken); + } + + public Task SubscribeAsync(MqttClientSubscribeOptions options, CancellationToken cancellationToken) + { + return Implementation.SubscribeAsync(options, cancellationToken); + } + + public Task UnsubscribeAsync(MqttClientUnsubscribeOptions options, CancellationToken cancellationToken) + { + return Implementation.UnsubscribeAsync(options, cancellationToken); + } + } +} \ No newline at end of file diff --git a/Tests/MQTTnet.Core.Tests/Mockups/TestEnvironment.cs b/Tests/MQTTnet.Core.Tests/Mockups/TestEnvironment.cs index 5cda537..7f2dcac 100644 --- a/Tests/MQTTnet.Core.Tests/Mockups/TestEnvironment.cs +++ b/Tests/MQTTnet.Core.Tests/Mockups/TestEnvironment.cs @@ -2,14 +2,16 @@ using System.Collections.Generic; using System.Linq; using System.Threading.Tasks; +using Microsoft.VisualStudio.TestTools.UnitTesting; using MQTTnet.Client; using MQTTnet.Client.Options; using MQTTnet.Diagnostics; +using MQTTnet.Internal; using MQTTnet.Server; namespace MQTTnet.Tests.Mockups { - public class TestEnvironment : IDisposable + public class TestEnvironment : Disposable { private readonly MqttFactory _mqttFactory = new MqttFactory(); private readonly List _clients = new List(); @@ -33,7 +35,9 @@ namespace MQTTnet.Tests.Mockups public IMqttNetLogger ClientLogger => _clientLogger; - public TestEnvironment() + public TestContext TestContext { get; } + + public TestEnvironment(TestContext testContext) { _serverLogger.LogMessagePublished += (s, e) => { @@ -56,13 +60,14 @@ namespace MQTTnet.Tests.Mockups } } }; + TestContext = testContext; } public IMqttClient CreateClient() { var client = _mqttFactory.CreateMqttClient(_clientLogger); _clients.Add(client); - return client; + return new TestClientWrapper(client, TestContext); } public Task StartServerAsync() @@ -77,7 +82,7 @@ namespace MQTTnet.Tests.Mockups throw new InvalidOperationException("Server already started."); } - Server = _mqttFactory.CreateMqttServer(_serverLogger); + Server = new TestServerWrapper(_mqttFactory.CreateMqttServer(_serverLogger), TestContext, this); await Server.StartAsync(options.WithDefaultEndpointPort(ServerPort).Build()); return Server; @@ -85,7 +90,7 @@ namespace MQTTnet.Tests.Mockups public Task ConnectClientAsync() { - return ConnectClientAsync(new MqttClientOptionsBuilder()); + return ConnectClientAsync(new MqttClientOptionsBuilder() ); } public async Task ConnectClientAsync(MqttClientOptionsBuilder options) @@ -127,21 +132,25 @@ namespace MQTTnet.Tests.Mockups } } - public void Dispose() + protected override void Dispose(bool disposing) { - foreach (var mqttClient in _clients) + if (disposing) { - mqttClient?.Dispose(); - } + foreach (var mqttClient in _clients) + { + mqttClient?.Dispose(); + } - Server?.StopAsync().GetAwaiter().GetResult(); + Server?.StopAsync().GetAwaiter().GetResult(); - ThrowIfLogErrors(); + ThrowIfLogErrors(); - if (_exceptions.Any()) - { - throw new Exception($"{_exceptions.Count} exceptions tracked.\r\n" + string.Join(Environment.NewLine, _exceptions)); + if (_exceptions.Any()) + { + throw new Exception($"{_exceptions.Count} exceptions tracked.\r\n" + string.Join(Environment.NewLine, _exceptions)); + } } + base.Dispose(disposing); } public void TrackException(Exception exception) diff --git a/Tests/MQTTnet.Core.Tests/Mockups/TestServerWrapper.cs b/Tests/MQTTnet.Core.Tests/Mockups/TestServerWrapper.cs new file mode 100644 index 0000000..f990197 --- /dev/null +++ b/Tests/MQTTnet.Core.Tests/Mockups/TestServerWrapper.cs @@ -0,0 +1,108 @@ +using System; +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using MQTTnet.Client.Publishing; +using MQTTnet.Client.Receiving; +using MQTTnet.Server; +using MQTTnet.Server.Status; + +namespace MQTTnet.Tests.Mockups +{ + public class TestServerWrapper : IMqttServer + { + public TestServerWrapper(IMqttServer implementation, TestContext testContext, TestEnvironment testEnvironment) + { + Implementation = implementation; + TestContext = testContext; + TestEnvironment = testEnvironment; + } + + public IMqttServer Implementation { get; } + public TestContext TestContext { get; } + public TestEnvironment TestEnvironment { get; } + public IMqttServerStartedHandler StartedHandler { get => Implementation.StartedHandler; set => Implementation.StartedHandler = value; } + public IMqttServerStoppedHandler StoppedHandler { get => Implementation.StoppedHandler; set => Implementation.StoppedHandler = value; } + public IMqttServerClientConnectedHandler ClientConnectedHandler { get => Implementation.ClientConnectedHandler; set => Implementation.ClientConnectedHandler = value; } + public IMqttServerClientDisconnectedHandler ClientDisconnectedHandler { get => Implementation.ClientDisconnectedHandler; set => Implementation.ClientDisconnectedHandler = value; } + public IMqttServerClientSubscribedTopicHandler ClientSubscribedTopicHandler { get => Implementation.ClientSubscribedTopicHandler; set => Implementation.ClientSubscribedTopicHandler = value; } + public IMqttServerClientUnsubscribedTopicHandler ClientUnsubscribedTopicHandler { get => Implementation.ClientUnsubscribedTopicHandler; set => Implementation.ClientUnsubscribedTopicHandler = value; } + + public IMqttServerOptions Options => Implementation.Options; + + public IMqttApplicationMessageReceivedHandler ApplicationMessageReceivedHandler { get => Implementation.ApplicationMessageReceivedHandler; set => Implementation.ApplicationMessageReceivedHandler = value; } + + public Task ClearRetainedApplicationMessagesAsync() + { + return Implementation.ClearRetainedApplicationMessagesAsync(); + } + + public Task> GetClientStatusAsync() + { + return Implementation.GetClientStatusAsync(); + } + + public Task> GetRetainedApplicationMessagesAsync() + { + return Implementation.GetRetainedApplicationMessagesAsync(); + } + + public Task> GetSessionStatusAsync() + { + return Implementation.GetSessionStatusAsync(); + } + + public Task PublishAsync(MqttApplicationMessage applicationMessage, CancellationToken cancellationToken) + { + return Implementation.PublishAsync(applicationMessage, cancellationToken); + } + + public Task StartAsync(IMqttServerOptions options) + { + switch (options) + { + case MqttServerOptionsBuilder builder: + if (builder.Build().ConnectionValidator == null) + { + builder.WithConnectionValidator(ConnectionValidator); + } + break; + case MqttServerOptions op: + if (op.ConnectionValidator == null) + { + op.ConnectionValidator = new MqttServerConnectionValidatorDelegate(ConnectionValidator); + } + break; + default: + break; + } + + return Implementation.StartAsync(options); + } + + public void ConnectionValidator(MqttConnectionValidatorContext ctx) + { + if (!ctx.ClientId.StartsWith(TestContext.TestName)) + { + TestEnvironment.TrackException(new InvalidOperationException($"invalid client connected '{ctx.ClientId}'")); + ctx.ReasonCode = Protocol.MqttConnectReasonCode.ClientIdentifierNotValid; + } + } + + public Task StopAsync() + { + return Implementation.StopAsync(); + } + + public Task SubscribeAsync(string clientId, ICollection topicFilters) + { + return Implementation.SubscribeAsync(clientId, topicFilters); + } + + public Task UnsubscribeAsync(string clientId, ICollection topicFilters) + { + return Implementation.UnsubscribeAsync(clientId, topicFilters); + } + } +} \ No newline at end of file diff --git a/Tests/MQTTnet.Core.Tests/MqttClient_Tests.cs b/Tests/MQTTnet.Core.Tests/MqttClient_Tests.cs index 51d1753..b8f4fbc 100644 --- a/Tests/MQTTnet.Core.Tests/MqttClient_Tests.cs +++ b/Tests/MQTTnet.Core.Tests/MqttClient_Tests.cs @@ -20,10 +20,12 @@ namespace MQTTnet.Tests [TestClass] public class Client_Tests { + public TestContext TestContext { get; set; } + [TestMethod] public async Task Send_Reply_In_Message_Handler_For_Same_Client() { - using (var testEnvironment = new TestEnvironment()) + using (var testEnvironment = new TestEnvironment(TestContext)) { await testEnvironment.StartServerAsync(); var client = await testEnvironment.ConnectClientAsync(); @@ -57,7 +59,7 @@ namespace MQTTnet.Tests [TestMethod] public async Task Send_Reply_In_Message_Handler() { - using (var testEnvironment = new TestEnvironment()) + using (var testEnvironment = new TestEnvironment(TestContext)) { await testEnvironment.StartServerAsync(); var client1 = await testEnvironment.ConnectClientAsync(); @@ -89,7 +91,7 @@ namespace MQTTnet.Tests [TestMethod] public async Task Reconnect() { - using (var testEnvironment = new TestEnvironment()) + using (var testEnvironment = new TestEnvironment(TestContext)) { var server = await testEnvironment.StartServerAsync(); var client = await testEnvironment.ConnectClientAsync(); @@ -112,7 +114,7 @@ namespace MQTTnet.Tests [TestMethod] public async Task Reconnect_While_Server_Offline() { - using (var testEnvironment = new TestEnvironment()) + using (var testEnvironment = new TestEnvironment(TestContext)) { testEnvironment.IgnoreClientLogErrors = true; @@ -149,7 +151,7 @@ namespace MQTTnet.Tests [TestMethod] public async Task Reconnect_From_Disconnected_Event() { - using (var testEnvironment = new TestEnvironment()) + using (var testEnvironment = new TestEnvironment(TestContext)) { testEnvironment.IgnoreClientLogErrors = true; @@ -189,7 +191,7 @@ namespace MQTTnet.Tests [TestMethod] public async Task PacketIdentifier_In_Publish_Result() { - using (var testEnvironment = new TestEnvironment()) + using (var testEnvironment = new TestEnvironment(TestContext)) { await testEnvironment.StartServerAsync(); var client = await testEnvironment.ConnectClientAsync(); @@ -235,10 +237,40 @@ namespace MQTTnet.Tests } } + [TestMethod] + public async Task ConnectTimeout_Throws_Exception() + { + var factory = new MqttFactory(); + using (var client = factory.CreateMqttClient()) + { + bool disconnectHandlerCalled = false; + try + { + client.DisconnectedHandler = new MqttClientDisconnectedHandlerDelegate(args => + { + disconnectHandlerCalled = true; + }); + + await client.ConnectAsync(new MqttClientOptionsBuilder().WithTcpServer("1.2.3.4").Build()); + + Assert.Fail("Must fail!"); + } + catch (Exception exception) + { + Assert.IsNotNull(exception); + Assert.IsInstanceOfType(exception, typeof(MqttCommunicationException)); + //Assert.IsInstanceOfType(exception.InnerException, typeof(SocketException)); + } + + await Task.Delay(100); // disconnected handler is called async + Assert.IsTrue(disconnectHandlerCalled); + } + } + [TestMethod] public async Task Fire_Disconnected_Event_On_Server_Shutdown() { - using (var testEnvironment = new TestEnvironment()) + using (var testEnvironment = new TestEnvironment(TestContext)) { var server = await testEnvironment.StartServerAsync(); var client = await testEnvironment.ConnectClientAsync(); @@ -290,7 +322,7 @@ namespace MQTTnet.Tests // is an issue). const int MessagesCount = 50; - using (var testEnvironment = new TestEnvironment()) + using (var testEnvironment = new TestEnvironment(TestContext)) { await testEnvironment.StartServerAsync(); @@ -330,7 +362,7 @@ namespace MQTTnet.Tests [TestMethod] public async Task Send_Reply_For_Any_Received_Message() { - using (var testEnvironment = new TestEnvironment()) + using (var testEnvironment = new TestEnvironment(TestContext)) { await testEnvironment.StartServerAsync(); @@ -374,7 +406,7 @@ namespace MQTTnet.Tests [TestMethod] public async Task Publish_With_Correct_Retain_Flag() { - using (var testEnvironment = new TestEnvironment()) + using (var testEnvironment = new TestEnvironment(TestContext)) { await testEnvironment.StartServerAsync(); @@ -405,7 +437,7 @@ namespace MQTTnet.Tests [TestMethod] public async Task Subscribe_In_Callback_Events() { - using (var testEnvironment = new TestEnvironment()) + using (var testEnvironment = new TestEnvironment(TestContext)) { await testEnvironment.StartServerAsync(); @@ -444,7 +476,7 @@ namespace MQTTnet.Tests [TestMethod] public async Task Message_Send_Retry() { - using (var testEnvironment = new TestEnvironment()) + using (var testEnvironment = new TestEnvironment(TestContext)) { testEnvironment.IgnoreClientLogErrors = true; testEnvironment.IgnoreServerLogErrors = true; @@ -488,7 +520,7 @@ namespace MQTTnet.Tests [TestMethod] public async Task NoConnectedHandler_Connect_DoesNotThrowException() { - using (var testEnvironment = new TestEnvironment()) + using (var testEnvironment = new TestEnvironment(TestContext)) { await testEnvironment.StartServerAsync(); @@ -501,7 +533,7 @@ namespace MQTTnet.Tests [TestMethod] public async Task NoDisconnectedHandler_Disconnect_DoesNotThrowException() { - using (var testEnvironment = new TestEnvironment()) + using (var testEnvironment = new TestEnvironment(TestContext)) { await testEnvironment.StartServerAsync(); var client = await testEnvironment.ConnectClientAsync(); @@ -516,7 +548,7 @@ namespace MQTTnet.Tests [TestMethod] public async Task Frequent_Connects() { - using (var testEnvironment = new TestEnvironment()) + using (var testEnvironment = new TestEnvironment(TestContext)) { await testEnvironment.StartServerAsync(); @@ -560,7 +592,7 @@ namespace MQTTnet.Tests [TestMethod] public async Task No_Payload() { - using (var testEnvironment = new TestEnvironment()) + using (var testEnvironment = new TestEnvironment(TestContext)) { await testEnvironment.StartServerAsync(); diff --git a/Tests/MQTTnet.Core.Tests/MqttFactory_Tests.cs b/Tests/MQTTnet.Core.Tests/MqttFactory_Tests.cs index 75bda9b..eb7aba3 100644 --- a/Tests/MQTTnet.Core.Tests/MqttFactory_Tests.cs +++ b/Tests/MQTTnet.Core.Tests/MqttFactory_Tests.cs @@ -18,7 +18,7 @@ namespace MQTTnet.Tests //This test compares //1. correct logID string logId = "logId"; - bool invalidLogIdOccured = false; + string invalidLogId = null; //2. if the total log calls are the same for global and local int globalLogCount = 0; @@ -31,7 +31,7 @@ namespace MQTTnet.Tests { if (logId != e.TraceMessage.LogId) { - invalidLogIdOccured = true; + invalidLogId = e.TraceMessage.LogId; } Interlocked.Increment(ref globalLogCount); }); @@ -42,7 +42,7 @@ namespace MQTTnet.Tests { if (logId != e.TraceMessage.LogId) { - invalidLogIdOccured = true; + invalidLogId = e.TraceMessage.LogId; } Interlocked.Increment(ref localLogCount); }; @@ -69,7 +69,7 @@ namespace MQTTnet.Tests MqttNetGlobalLogger.LogMessagePublished -= globalLog; } - Assert.IsFalse(invalidLogIdOccured); + Assert.IsNull(invalidLogId); Assert.AreNotEqual(0, globalLogCount); Assert.AreEqual(globalLogCount, localLogCount); } diff --git a/Tests/MQTTnet.Core.Tests/MqttTcpChannel_Tests.cs b/Tests/MQTTnet.Core.Tests/MqttTcpChannel_Tests.cs index 436d2d1..a4b0ca7 100644 --- a/Tests/MQTTnet.Core.Tests/MqttTcpChannel_Tests.cs +++ b/Tests/MQTTnet.Core.Tests/MqttTcpChannel_Tests.cs @@ -28,14 +28,14 @@ namespace MQTTnet.Tests { while (!ct.IsCancellationRequested) { - var client = await serverSocket.AcceptAsync(); + var client = await PlatformAbstractionLayer.AcceptAsync(serverSocket); var data = new byte[] { 128 }; - await client.SendAsync(new ArraySegment(data), SocketFlags.None); + await PlatformAbstractionLayer.SendAsync(client, new ArraySegment(data), SocketFlags.None); } }, ct.Token); var clientSocket = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); - await clientSocket.ConnectAsync(IPAddress.Loopback, 50001); + await PlatformAbstractionLayer.ConnectAsync(clientSocket, IPAddress.Loopback, 50001); await Task.Delay(100, ct.Token); diff --git a/Tests/MQTTnet.Core.Tests/RPC_Tests.cs b/Tests/MQTTnet.Core.Tests/RPC_Tests.cs index a420697..947c104 100644 --- a/Tests/MQTTnet.Core.Tests/RPC_Tests.cs +++ b/Tests/MQTTnet.Core.Tests/RPC_Tests.cs @@ -18,6 +18,8 @@ namespace MQTTnet.Tests [TestClass] public class RPC_Tests { + public TestContext TestContext { get; set; } + [TestMethod] public Task Execute_Success_With_QoS_0() { @@ -58,7 +60,7 @@ namespace MQTTnet.Tests [ExpectedException(typeof(MqttCommunicationTimedOutException))] public async Task Execute_Timeout() { - using (var testEnvironment = new TestEnvironment()) + using (var testEnvironment = new TestEnvironment(TestContext)) { await testEnvironment.StartServerAsync(); @@ -73,7 +75,7 @@ namespace MQTTnet.Tests [ExpectedException(typeof(MqttCommunicationTimedOutException))] public async Task Execute_With_Custom_Topic_Names() { - using (var testEnvironment = new TestEnvironment()) + using (var testEnvironment = new TestEnvironment(TestContext)) { await testEnvironment.StartServerAsync(); @@ -86,7 +88,7 @@ namespace MQTTnet.Tests private async Task Execute_Success(MqttQualityOfServiceLevel qosLevel, MqttProtocolVersion protocolVersion) { - using (var testEnvironment = new TestEnvironment()) + using (var testEnvironment = new TestEnvironment(TestContext)) { await testEnvironment.StartServerAsync(); var responseSender = await testEnvironment.ConnectClientAsync(new MqttClientOptionsBuilder().WithProtocolVersion(protocolVersion)); diff --git a/Tests/MQTTnet.Core.Tests/RoundtripTime_Tests.cs b/Tests/MQTTnet.Core.Tests/RoundtripTime_Tests.cs index 02055a8..377d142 100644 --- a/Tests/MQTTnet.Core.Tests/RoundtripTime_Tests.cs +++ b/Tests/MQTTnet.Core.Tests/RoundtripTime_Tests.cs @@ -11,10 +11,12 @@ namespace MQTTnet.Tests [TestClass] public class RoundtripTime_Tests { + public TestContext TestContext { get; set; } + [TestMethod] public async Task Round_Trip_Time() { - using (var testEnvironment = new TestEnvironment()) + using (var testEnvironment = new TestEnvironment(TestContext)) { await testEnvironment.StartServerAsync(); var receiverClient = await testEnvironment.ConnectClientAsync(); diff --git a/Tests/MQTTnet.Core.Tests/Server_Status_Tests.cs b/Tests/MQTTnet.Core.Tests/Server_Status_Tests.cs index 8111bd0..93c1d55 100644 --- a/Tests/MQTTnet.Core.Tests/Server_Status_Tests.cs +++ b/Tests/MQTTnet.Core.Tests/Server_Status_Tests.cs @@ -6,16 +6,19 @@ using MQTTnet.Tests.Mockups; using MQTTnet.Client; using MQTTnet.Protocol; using MQTTnet.Server; +using System.Threading; namespace MQTTnet.Tests { [TestClass] public class Server_Status_Tests { + public TestContext TestContext { get; set; } + [TestMethod] public async Task Show_Client_And_Session_Statistics() { - using (var testEnvironment = new TestEnvironment()) + using (var testEnvironment = new TestEnvironment(TestContext)) { var server = await testEnvironment.StartServerAsync(); @@ -30,8 +33,8 @@ namespace MQTTnet.Tests Assert.AreEqual(2, clientStatus.Count); Assert.AreEqual(2, sessionStatus.Count); - Assert.IsTrue(clientStatus.Any(s => s.ClientId == "client1")); - Assert.IsTrue(clientStatus.Any(s => s.ClientId == "client2")); + Assert.IsTrue(clientStatus.Any(s => s.ClientId == c1.Options.ClientId)); + Assert.IsTrue(clientStatus.Any(s => s.ClientId == c2.Options.ClientId)); await c1.DisconnectAsync(); await c2.DisconnectAsync(); @@ -49,19 +52,19 @@ namespace MQTTnet.Tests [TestMethod] public async Task Disconnect_Client() { - using (var testEnvironment = new TestEnvironment()) + using (var testEnvironment = new TestEnvironment(TestContext)) { var server = await testEnvironment.StartServerAsync(); var c1 = await testEnvironment.ConnectClientAsync(new MqttClientOptionsBuilder().WithClientId("client1")); - await Task.Delay(500); + await Task.Delay(1000); var clientStatus = await server.GetClientStatusAsync(); - + Assert.AreEqual(1, clientStatus.Count); - Assert.IsTrue(clientStatus.Any(s => s.ClientId == "client1")); + Assert.IsTrue(clientStatus.Any(s => s.ClientId == c1.Options.ClientId)); await clientStatus.First().DisconnectAsync(); @@ -78,7 +81,7 @@ namespace MQTTnet.Tests [TestMethod] public async Task Keep_Persistent_Session() { - using (var testEnvironment = new TestEnvironment()) + using (var testEnvironment = new TestEnvironment(TestContext)) { var server = await testEnvironment.StartServerAsync(new MqttServerOptionsBuilder().WithPersistentSessions()); @@ -110,7 +113,7 @@ namespace MQTTnet.Tests [TestMethod] public async Task Track_Sent_Application_Messages() { - using (var testEnvironment = new TestEnvironment()) + using (var testEnvironment = new TestEnvironment(TestContext)) { var server = await testEnvironment.StartServerAsync(new MqttServerOptionsBuilder().WithPersistentSessions()); @@ -131,7 +134,7 @@ namespace MQTTnet.Tests [TestMethod] public async Task Track_Sent_Packets() { - using (var testEnvironment = new TestEnvironment()) + using (var testEnvironment = new TestEnvironment(TestContext)) { var server = await testEnvironment.StartServerAsync(new MqttServerOptionsBuilder().WithPersistentSessions()); diff --git a/Tests/MQTTnet.Core.Tests/Server_Tests.cs b/Tests/MQTTnet.Core.Tests/Server_Tests.cs index b2b3b70..3ebdaa6 100644 --- a/Tests/MQTTnet.Core.Tests/Server_Tests.cs +++ b/Tests/MQTTnet.Core.Tests/Server_Tests.cs @@ -13,6 +13,7 @@ using MQTTnet.Client.Disconnecting; using MQTTnet.Client.Options; using MQTTnet.Client.Receiving; using MQTTnet.Client.Subscribing; +using MQTTnet.Implementations; using MQTTnet.Protocol; using MQTTnet.Server; using MQTTnet.Tests.Mockups; @@ -22,10 +23,12 @@ namespace MQTTnet.Tests [TestClass] public class Server_Tests { + public TestContext TestContext { get; set; } + [TestMethod] public async Task Use_Empty_Client_ID() { - using (var testEnvironment = new TestEnvironment()) + using (var testEnvironment = new TestEnvironment(TestContext)) { await testEnvironment.StartServerAsync(); @@ -51,7 +54,8 @@ namespace MQTTnet.Tests MqttQualityOfServiceLevel.AtMostOnce, "A/B/C", MqttQualityOfServiceLevel.AtMostOnce, - 1); + 1, + TestContext); } [TestMethod] @@ -62,7 +66,8 @@ namespace MQTTnet.Tests MqttQualityOfServiceLevel.AtLeastOnce, "A/B/C", MqttQualityOfServiceLevel.AtLeastOnce, - 1); + 1, + TestContext); } [TestMethod] @@ -73,13 +78,14 @@ namespace MQTTnet.Tests MqttQualityOfServiceLevel.ExactlyOnce, "A/B/C", MqttQualityOfServiceLevel.ExactlyOnce, - 1); + 1, + TestContext); } [TestMethod] public async Task Use_Clean_Session() { - using (var testEnvironment = new TestEnvironment()) + using (var testEnvironment = new TestEnvironment(TestContext)) { await testEnvironment.StartServerAsync(); @@ -93,7 +99,7 @@ namespace MQTTnet.Tests [TestMethod] public async Task Will_Message_Do_Not_Send() { - using (var testEnvironment = new TestEnvironment()) + using (var testEnvironment = new TestEnvironment(TestContext)) { var receivedMessagesCount = 0; @@ -119,7 +125,7 @@ namespace MQTTnet.Tests [TestMethod] public async Task Will_Message_Send() { - using (var testEnvironment = new TestEnvironment()) + using (var testEnvironment = new TestEnvironment(TestContext)) { var receivedMessagesCount = 0; @@ -145,7 +151,7 @@ namespace MQTTnet.Tests [TestMethod] public async Task Intercept_Subscription() { - using (var testEnvironment = new TestEnvironment()) + using (var testEnvironment = new TestEnvironment(TestContext)) { await testEnvironment.StartServerAsync(new MqttServerOptionsBuilder().WithSubscriptionInterceptor( c => @@ -184,7 +190,7 @@ namespace MQTTnet.Tests [TestMethod] public async Task Subscribe_Unsubscribe() { - using (var testEnvironment = new TestEnvironment()) + using (var testEnvironment = new TestEnvironment(TestContext)) { var receivedMessagesCount = 0; @@ -204,7 +210,7 @@ namespace MQTTnet.Tests var subscribeEventCalled = false; server.ClientSubscribedTopicHandler = new MqttServerClientSubscribedHandlerDelegate(e => { - subscribeEventCalled = e.TopicFilter.Topic == "a" && e.ClientId == "c1"; + subscribeEventCalled = e.TopicFilter.Topic == "a" && e.ClientId == c1.Options.ClientId; }); await c1.SubscribeAsync(new TopicFilter { Topic = "a", QualityOfServiceLevel = MqttQualityOfServiceLevel.AtLeastOnce }); @@ -218,7 +224,7 @@ namespace MQTTnet.Tests var unsubscribeEventCalled = false; server.ClientUnsubscribedTopicHandler = new MqttServerClientUnsubscribedTopicHandlerDelegate(e => { - unsubscribeEventCalled = e.TopicFilter == "a" && e.ClientId == "c1"; + unsubscribeEventCalled = e.TopicFilter == "a" && e.ClientId == c1.Options.ClientId; }); await c1.UnsubscribeAsync("a"); @@ -238,7 +244,7 @@ namespace MQTTnet.Tests [TestMethod] public async Task Subscribe_Multiple_In_Single_Request() { - using (var testEnvironment = new TestEnvironment()) + using (var testEnvironment = new TestEnvironment(TestContext)) { var receivedMessagesCount = 0; @@ -271,7 +277,7 @@ namespace MQTTnet.Tests [TestMethod] public async Task Subscribe_Multiple_In_Multiple_Request() { - using (var testEnvironment = new TestEnvironment()) + using (var testEnvironment = new TestEnvironment(TestContext)) { var receivedMessagesCount = 0; @@ -310,7 +316,7 @@ namespace MQTTnet.Tests [TestMethod] public async Task Publish_From_Server() { - using (var testEnvironment = new TestEnvironment()) + using (var testEnvironment = new TestEnvironment(TestContext)) { var server = await testEnvironment.StartServerAsync(); @@ -336,7 +342,7 @@ namespace MQTTnet.Tests var receivedMessagesCount = 0; var locked = new object(); - using (var testEnvironment = new TestEnvironment()) + using (var testEnvironment = new TestEnvironment(TestContext)) { await testEnvironment.StartServerAsync(); @@ -378,7 +384,7 @@ namespace MQTTnet.Tests [TestMethod] public async Task Session_Takeover() { - using (var testEnvironment = new TestEnvironment()) + using (var testEnvironment = new TestEnvironment(TestContext)) { await testEnvironment.StartServerAsync(); @@ -400,7 +406,7 @@ namespace MQTTnet.Tests [TestMethod] public async Task No_Messages_If_No_Subscription() { - using (var testEnvironment = new TestEnvironment()) + using (var testEnvironment = new TestEnvironment(TestContext)) { await testEnvironment.StartServerAsync(); @@ -433,7 +439,7 @@ namespace MQTTnet.Tests [TestMethod] public async Task Set_Subscription_At_Server() { - using (var testEnvironment = new TestEnvironment()) + using (var testEnvironment = new TestEnvironment(TestContext)) { var server = await testEnvironment.StartServerAsync(); server.ClientConnectedHandler = new MqttServerClientConnectedHandlerDelegate(e => server.SubscribeAsync(e.ClientId, "topic1")); @@ -464,7 +470,7 @@ namespace MQTTnet.Tests [TestMethod] public async Task Shutdown_Disconnects_Clients_Gracefully() { - using (var testEnvironment = new TestEnvironment()) + using (var testEnvironment = new TestEnvironment(TestContext)) { var server = await testEnvironment.StartServerAsync(new MqttServerOptionsBuilder()); @@ -486,7 +492,7 @@ namespace MQTTnet.Tests [TestMethod] public async Task Handle_Clean_Disconnect() { - using (var testEnvironment = new TestEnvironment()) + using (var testEnvironment = new TestEnvironment(TestContext)) { var server = await testEnvironment.StartServerAsync(new MqttServerOptionsBuilder()); @@ -515,7 +521,7 @@ namespace MQTTnet.Tests [TestMethod] public async Task Client_Disconnect_Without_Errors() { - using (var testEnvironment = new TestEnvironment()) + using (var testEnvironment = new TestEnvironment(TestContext)) { bool clientWasConnected; @@ -546,7 +552,7 @@ namespace MQTTnet.Tests { const int ClientCount = 50; - using (var testEnvironment = new TestEnvironment()) + using (var testEnvironment = new TestEnvironment(TestContext)) { var server = await testEnvironment.StartServerAsync(); @@ -598,7 +604,7 @@ namespace MQTTnet.Tests [TestMethod] public async Task Retained_Messages_Flow() { - using (var testEnvironment = new TestEnvironment()) + using (var testEnvironment = new TestEnvironment(TestContext)) { var retainedMessage = new MqttApplicationMessageBuilder().WithTopic("r").WithPayload("r").WithRetainFlag().Build(); @@ -635,7 +641,7 @@ namespace MQTTnet.Tests [TestMethod] public async Task Receive_No_Retained_Message_After_Subscribe() { - using (var testEnvironment = new TestEnvironment()) + using (var testEnvironment = new TestEnvironment(TestContext)) { await testEnvironment.StartServerAsync(); @@ -658,7 +664,7 @@ namespace MQTTnet.Tests [TestMethod] public async Task Receive_Retained_Message_After_Subscribe() { - using (var testEnvironment = new TestEnvironment()) + using (var testEnvironment = new TestEnvironment(TestContext)) { await testEnvironment.StartServerAsync(); @@ -689,7 +695,7 @@ namespace MQTTnet.Tests [TestMethod] public async Task Clear_Retained_Message_With_Empty_Payload() { - using (var testEnvironment = new TestEnvironment()) + using (var testEnvironment = new TestEnvironment(TestContext)) { var receivedMessagesCount = 0; @@ -717,7 +723,7 @@ namespace MQTTnet.Tests [TestMethod] public async Task Clear_Retained_Message_With_Null_Payload() { - using (var testEnvironment = new TestEnvironment()) + using (var testEnvironment = new TestEnvironment(TestContext)) { var receivedMessagesCount = 0; @@ -745,7 +751,7 @@ namespace MQTTnet.Tests [TestMethod] public async Task Intercept_Application_Message() { - using (var testEnvironment = new TestEnvironment()) + using (var testEnvironment = new TestEnvironment(TestContext)) { await testEnvironment.StartServerAsync( new MqttServerOptionsBuilder().WithApplicationMessageInterceptor( @@ -768,7 +774,7 @@ namespace MQTTnet.Tests { var serverStorage = new TestServerStorage(); - using (var testEnvironment = new TestEnvironment()) + using (var testEnvironment = new TestEnvironment(TestContext)) { await testEnvironment.StartServerAsync(new MqttServerOptionsBuilder().WithStorage(serverStorage)); @@ -785,7 +791,7 @@ namespace MQTTnet.Tests [TestMethod] public async Task Publish_After_Client_Connects() { - using (var testEnvironment = new TestEnvironment()) + using (var testEnvironment = new TestEnvironment(TestContext)) { var server = await testEnvironment.StartServerAsync(); server.UseClientConnectedHandler(async e => @@ -818,7 +824,7 @@ namespace MQTTnet.Tests context.ApplicationMessage.Payload = Encoding.ASCII.GetBytes("extended"); } - using (var testEnvironment = new TestEnvironment()) + using (var testEnvironment = new TestEnvironment(TestContext)) { await testEnvironment.StartServerAsync(new MqttServerOptionsBuilder().WithApplicationMessageInterceptor(Interceptor)); @@ -844,7 +850,7 @@ namespace MQTTnet.Tests [TestMethod] public async Task Send_Long_Body() { - using (var testEnvironment = new TestEnvironment()) + using (var testEnvironment = new TestEnvironment(TestContext)) { const int PayloadSizeInMB = 30; const int CharCount = PayloadSizeInMB * 1024 * 1024; @@ -889,38 +895,128 @@ namespace MQTTnet.Tests { var serverOptions = new MqttServerOptionsBuilder().WithConnectionValidator(context => { - context.ReturnCode = MqttConnectReturnCode.ConnectionRefusedNotAuthorized; + context.ReasonCode = MqttConnectReasonCode.NotAuthorized; }); - using (var testEnvironment = new TestEnvironment()) + using (var testEnvironment = new TestEnvironment(TestContext)) { testEnvironment.IgnoreClientLogErrors = true; await testEnvironment.StartServerAsync(serverOptions); - try + + var connectingFailedException = await Assert.ThrowsExceptionAsync(() => testEnvironment.ConnectClientAsync()); + Assert.AreEqual(MqttClientConnectResultCode.NotAuthorized, connectingFailedException.ResultCode); + } + } + + + private Dictionary _connected; + private void ConnectionValidationHandler(MqttConnectionValidatorContext eventArgs) + { + if (_connected.ContainsKey(eventArgs.ClientId)) + { + eventArgs.ReasonCode = MqttConnectReasonCode.BadUserNameOrPassword; + return; + } + _connected[eventArgs.ClientId] = true; + eventArgs.ReasonCode = MqttConnectReasonCode.Success; + return; + } + + [TestMethod] + public async Task Same_Client_Id_Refuse_Connection() + { + using (var testEnvironment = new TestEnvironment(TestContext)) + { + testEnvironment.IgnoreClientLogErrors = true; + + _connected = new Dictionary(); + var options = new MqttServerOptionsBuilder(); + options.WithConnectionValidator(e => ConnectionValidationHandler(e)); + var server = await testEnvironment.StartServerAsync(options); + + var events = new List(); + + server.ClientConnectedHandler = new MqttServerClientConnectedHandlerDelegate(_ => { - await testEnvironment.ConnectClientAsync(); - Assert.Fail("An exception should be raised."); - } - catch (Exception exception) + lock (events) + { + events.Add("c"); + } + }); + + server.ClientDisconnectedHandler = new MqttServerClientDisconnectedHandlerDelegate(_ => + { + lock (events) + { + events.Add("d"); + } + }); + + var clientOptions = new MqttClientOptionsBuilder() + .WithClientId("same_id"); + + // c + var c1 = await testEnvironment.ConnectClientAsync(clientOptions); + + c1.UseDisconnectedHandler(_ => { - if (exception is MqttConnectingFailedException connectingFailedException) + lock (events) { - Assert.AreEqual(MqttClientConnectResultCode.NotAuthorized, connectingFailedException.ResultCode); + events.Add("x"); } - else + }); + + + c1.UseApplicationMessageReceivedHandler(_ => + { + lock (events) { - Assert.Fail("Wrong exception."); + events.Add("r"); } + + }); + + c1.SubscribeAsync("topic").Wait(); + + await Task.Delay(500); + + c1.PublishAsync("topic").Wait(); + + await Task.Delay(500); + + var flow = string.Join(string.Empty, events); + Assert.AreEqual("cr", flow); + + try + { + await testEnvironment.ConnectClientAsync(clientOptions); + Assert.Fail("same id connection is expected to fail"); + } + catch + { + //same id connection is expected to fail } + + await Task.Delay(500); + + flow = string.Join(string.Empty, events); + Assert.AreEqual("cr", flow); + + c1.PublishAsync("topic").Wait(); + + await Task.Delay(500); + + flow = string.Join(string.Empty, events); + Assert.AreEqual("crr", flow); } } [TestMethod] public async Task Same_Client_Id_Connect_Disconnect_Event_Order() { - using (var testEnvironment = new TestEnvironment()) + using (var testEnvironment = new TestEnvironment(TestContext)) { var server = await testEnvironment.StartServerAsync(new MqttServerOptionsBuilder()); @@ -956,17 +1052,40 @@ namespace MQTTnet.Tests // dc var c2 = await testEnvironment.ConnectClientAsync(clientOptions); + c2.UseApplicationMessageReceivedHandler(_ => + { + lock (events) + { + events.Add("r"); + } + + }); + c2.SubscribeAsync("topic").Wait(); + await Task.Delay(500); flow = string.Join(string.Empty, events); Assert.AreEqual("cdc", flow); + // r + c2.PublishAsync("topic").Wait(); + + await Task.Delay(500); + + flow = string.Join(string.Empty, events); + Assert.AreEqual("cdcr", flow); + + // nothing + + Assert.AreEqual(false, c1.IsConnected); await c1.DisconnectAsync(); + Assert.AreEqual (false, c1.IsConnected); await Task.Delay(500); // d + Assert.AreEqual(true, c2.IsConnected); await c2.DisconnectAsync(); await Task.Delay(500); @@ -974,14 +1093,14 @@ namespace MQTTnet.Tests await server.StopAsync(); flow = string.Join(string.Empty, events); - Assert.AreEqual("cdcd", flow); + Assert.AreEqual("cdcrd", flow); } } [TestMethod] public async Task Remove_Session() { - using (var testEnvironment = new TestEnvironment()) + using (var testEnvironment = new TestEnvironment(TestContext)) { var server = await testEnvironment.StartServerAsync(new MqttServerOptionsBuilder()); @@ -1000,7 +1119,7 @@ namespace MQTTnet.Tests [TestMethod] public async Task Stop_And_Restart() { - using (var testEnvironment = new TestEnvironment()) + using (var testEnvironment = new TestEnvironment(TestContext)) { testEnvironment.IgnoreClientLogErrors = true; @@ -1022,23 +1141,23 @@ namespace MQTTnet.Tests await testEnvironment.ConnectClientAsync(); } } - + [TestMethod] public async Task Close_Idle_Connection() { - using (var testEnvironment = new TestEnvironment()) + using (var testEnvironment = new TestEnvironment(TestContext)) { await testEnvironment.StartServerAsync(new MqttServerOptionsBuilder().WithDefaultCommunicationTimeout(TimeSpan.FromSeconds(1))); var client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); - await client.ConnectAsync("localhost", testEnvironment.ServerPort); + await PlatformAbstractionLayer.ConnectAsync(client, "localhost", testEnvironment.ServerPort); // Don't send anything. The server should close the connection. await Task.Delay(TimeSpan.FromSeconds(3)); try { - var receivedBytes = await client.ReceiveAsync(new ArraySegment(new byte[10]), SocketFlags.Partial); + var receivedBytes = await PlatformAbstractionLayer.ReceiveAsync(client, new ArraySegment(new byte[10]), SocketFlags.Partial); if (receivedBytes == 0) { return; @@ -1055,14 +1174,14 @@ namespace MQTTnet.Tests [TestMethod] public async Task Send_Garbage() { - using (var testEnvironment = new TestEnvironment()) + using (var testEnvironment = new TestEnvironment(TestContext)) { await testEnvironment.StartServerAsync(new MqttServerOptionsBuilder().WithDefaultCommunicationTimeout(TimeSpan.FromSeconds(1))); // Send an invalid packet and ensure that the server will close the connection and stay in a waiting state // forever. This is security related. var client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); - await client.ConnectAsync("localhost", testEnvironment.ServerPort); + await PlatformAbstractionLayer.ConnectAsync(client, "localhost", testEnvironment.ServerPort); var buffer = Encoding.UTF8.GetBytes("Garbage"); client.Send(buffer, buffer.Length, SocketFlags.None); @@ -1071,7 +1190,7 @@ namespace MQTTnet.Tests try { - var receivedBytes = await client.ReceiveAsync(new ArraySegment(new byte[10]), SocketFlags.Partial); + var receivedBytes = await PlatformAbstractionLayer.ReceiveAsync(client, new ArraySegment(new byte[10]), SocketFlags.Partial); if (receivedBytes == 0) { return; @@ -1088,7 +1207,7 @@ namespace MQTTnet.Tests [TestMethod] public async Task Do_Not_Send_Retained_Messages_For_Denied_Subscription() { - using (var testEnvironment = new TestEnvironment()) + using (var testEnvironment = new TestEnvironment(TestContext)) { await testEnvironment.StartServerAsync(new MqttServerOptionsBuilder().WithSubscriptionInterceptor(c => { @@ -1132,7 +1251,7 @@ namespace MQTTnet.Tests [TestMethod] public async Task Collect_Messages_In_Disconnected_Session() { - using (var testEnvironment = new TestEnvironment()) + using (var testEnvironment = new TestEnvironment(TestContext)) { var server = await testEnvironment.StartServerAsync(new MqttServerOptionsBuilder().WithPersistentSessions()); @@ -1159,7 +1278,7 @@ namespace MQTTnet.Tests Assert.AreEqual(0, clientStatus.Count); Assert.AreEqual(2, sessionStatus.Count); - Assert.AreEqual(3, sessionStatus.First(s => s.ClientId == "a").PendingApplicationMessagesCount); + Assert.AreEqual(3, sessionStatus.First(s => s.ClientId == client1.Options.ClientId).PendingApplicationMessagesCount); } } @@ -1168,9 +1287,10 @@ namespace MQTTnet.Tests MqttQualityOfServiceLevel qualityOfServiceLevel, string topicFilter, MqttQualityOfServiceLevel filterQualityOfServiceLevel, - int expectedReceivedMessagesCount) + int expectedReceivedMessagesCount, + TestContext testContext) { - using (var testEnvironment = new TestEnvironment()) + using (var testEnvironment = new TestEnvironment(testContext)) { var receivedMessagesCount = 0; diff --git a/Tests/MQTTnet.Core.Tests/Session_Tests.cs b/Tests/MQTTnet.Core.Tests/Session_Tests.cs index d06bd4e..073f272 100644 --- a/Tests/MQTTnet.Core.Tests/Session_Tests.cs +++ b/Tests/MQTTnet.Core.Tests/Session_Tests.cs @@ -11,10 +11,12 @@ namespace MQTTnet.Tests [TestClass] public class Session_Tests { + public TestContext TestContext { get; set; } + [TestMethod] public async Task Set_Session_Item() { - using (var testEnvironment = new TestEnvironment()) + using (var testEnvironment = new TestEnvironment(TestContext)) { var serverOptions = new MqttServerOptionsBuilder() .WithConnectionValidator(delegate (MqttConnectionValidatorContext context) diff --git a/Tests/MQTTnet.TestApp.AspNetCore2/MQTTnet.TestApp.AspNetCore2.csproj b/Tests/MQTTnet.TestApp.AspNetCore2/MQTTnet.TestApp.AspNetCore2.csproj index 2139c7f..5c0aacd 100644 --- a/Tests/MQTTnet.TestApp.AspNetCore2/MQTTnet.TestApp.AspNetCore2.csproj +++ b/Tests/MQTTnet.TestApp.AspNetCore2/MQTTnet.TestApp.AspNetCore2.csproj @@ -11,14 +11,11 @@ - - - + - diff --git a/Tests/MQTTnet.TestApp.NetCore/MQTTnet.TestApp.NetCore.csproj b/Tests/MQTTnet.TestApp.NetCore/MQTTnet.TestApp.NetCore.csproj index d60256e..4fffc33 100644 --- a/Tests/MQTTnet.TestApp.NetCore/MQTTnet.TestApp.NetCore.csproj +++ b/Tests/MQTTnet.TestApp.NetCore/MQTTnet.TestApp.NetCore.csproj @@ -12,7 +12,7 @@ - + diff --git a/Tests/MQTTnet.TestApp.NetCore/ManagedClientTest.cs b/Tests/MQTTnet.TestApp.NetCore/ManagedClientTest.cs index dc7d925..c4fd68f 100644 --- a/Tests/MQTTnet.TestApp.NetCore/ManagedClientTest.cs +++ b/Tests/MQTTnet.TestApp.NetCore/ManagedClientTest.cs @@ -40,11 +40,11 @@ namespace MQTTnet.TestApp.NetCore Console.WriteLine(">> RECEIVED: " + e.ApplicationMessage.Topic); }); - await managedClient.PublishAsync(builder => builder.WithTopic("Step").WithPayload("1")); - await managedClient.PublishAsync(builder => builder.WithTopic("Step").WithPayload("2").WithAtLeastOnceQoS()); - await managedClient.StartAsync(options); + await managedClient.PublishAsync(builder => builder.WithTopic("Step").WithPayload("1")); + await managedClient.PublishAsync(builder => builder.WithTopic("Step").WithPayload("2").WithAtLeastOnceQoS()); + await managedClient.SubscribeAsync(new TopicFilter { Topic = "xyz", QualityOfServiceLevel = MqttQualityOfServiceLevel.AtMostOnce }); await managedClient.SubscribeAsync(new TopicFilter { Topic = "abc", QualityOfServiceLevel = MqttQualityOfServiceLevel.AtMostOnce }); diff --git a/Tests/MQTTnet.TestApp.NetCore/ServerTest.cs b/Tests/MQTTnet.TestApp.NetCore/ServerTest.cs index bd62671..b0d6534 100644 --- a/Tests/MQTTnet.TestApp.NetCore/ServerTest.cs +++ b/Tests/MQTTnet.TestApp.NetCore/ServerTest.cs @@ -30,7 +30,7 @@ namespace MQTTnet.TestApp.NetCore { if (p.Username != "USER" || p.Password != "PASS") { - p.ReturnCode = MqttConnectReturnCode.ConnectionRefusedBadUsernameOrPassword; + p.ReasonCode = MqttConnectReasonCode.BadUserNameOrPassword; } } }), diff --git a/appveyor.yml b/appveyor.yml index 498a7bf..ca539e1 100644 --- a/appveyor.yml +++ b/appveyor.yml @@ -14,5 +14,5 @@ build: verbosity: minimal test_script: -- cmd: dotnet vstest "%APPVEYOR_BUILD_FOLDER%\Tests\MQTTnet.Core.Tests\bin\Release\netcoreapp2.1\MQTTnet.Tests.dll" -- cmd: dotnet vstest "%APPVEYOR_BUILD_FOLDER%\Tests\MQTTnet.AspNetCore.Tests\bin\Release\netcoreapp2.1\MQTTnet.AspNetCore.Tests.dll" +- cmd: dotnet vstest "%APPVEYOR_BUILD_FOLDER%\Tests\MQTTnet.Core.Tests\bin\Release\netcoreapp2.2\MQTTnet.Tests.dll" +- cmd: dotnet vstest "%APPVEYOR_BUILD_FOLDER%\Tests\MQTTnet.AspNetCore.Tests\bin\Release\netcoreapp2.2\MQTTnet.AspNetCore.Tests.dll"