diff --git a/.gitignore b/.gitignore index 93a9b60..bc1ad4a 100644 --- a/.gitignore +++ b/.gitignore @@ -292,3 +292,5 @@ __pycache__/ *.map /Tests/MQTTnet.TestApp.NetCore/RetainedMessages.json + +Build/NuGet/ diff --git a/Build/MQTTnet.nuspec b/Build/MQTTnet.nuspec index c9d4016..6aafdc6 100644 --- a/Build/MQTTnet.nuspec +++ b/Build/MQTTnet.nuspec @@ -11,7 +11,16 @@ false MQTTnet is a high performance .NET library for MQTT based communication. It provides a MQTT client and a MQTT server (broker) and supports v3.1.0, v3.1.1 and v5.0.0 of the MQTT protocol. -* [Server] Moved new socket options to TCP options to avoid incompatibility with Linux hosts. +* [Core] Nuget packages with symbols are now also published to improve debugging. +* [Core] Improve task handling (thanks to @mwinterb) +* [ManagedClient] Fix a race condition in the message storage (thanks to @PaulFake). +* [Server] Added items dictionary to client session in order to share data across interceptors as along as the session exists. +* [Server] Exposed CONNECT packet properties in Application Message and Subscription interceptor. +* [Server] Fixed: Sending Large packets with AspnetCore based connection throws System.ArgumentException. +* [Server] Fixed wrong usage of socket option _NoDelay_. +* [Server] Added remote certificate validation callback (thanks to @rudacs). +* [Server] Add support for certificate passwords (thanks to @cslutgen). +* [MQTTnet.Server] Added REST API for publishing basic messages. Copyright Christian Kratky 2016-2019 MQTT Message Queue Telemetry Transport MQTTClient MQTTServer Server MQTTBroker Broker NETStandard IoT InternetOfThings Messaging Hardware Arduino Sensor Actuator M2M ESP Smart Home Cities Automation Xamarin diff --git a/Build/build.ps1 b/Build/build.ps1 index f5f9975..33f2767 100644 --- a/Build/build.ps1 +++ b/Build/build.ps1 @@ -59,12 +59,12 @@ Copy-Item MQTTnet.Extensions.WebSocket4Net.nuspec -Destination MQTTnet.Extension (Get-Content MQTTnet.Extensions.WebSocket4Net.nuspec) -replace '\$nugetVersion', $nugetVersion | Set-Content MQTTnet.Extensions.WebSocket4Net.nuspec New-Item -ItemType Directory -Force -Path .\NuGet -.\nuget.exe pack MQTTnet.nuspec -Verbosity detailed -Symbols -OutputDir "NuGet" -Version $nugetVersion -.\nuget.exe pack MQTTnet.NETStandard.nuspec -Verbosity detailed -Symbols -OutputDir "NuGet" -Version $nugetVersion -.\nuget.exe pack MQTTnet.AspNetCore.nuspec -Verbosity detailed -Symbols -OutputDir "NuGet" -Version $nugetVersion -.\nuget.exe pack MQTTnet.Extensions.Rpc.nuspec -Verbosity detailed -Symbols -OutputDir "NuGet" -Version $nugetVersion -.\nuget.exe pack MQTTnet.Extensions.ManagedClient.nuspec -Verbosity detailed -Symbols -OutputDir "NuGet" -Version $nugetVersion -.\nuget.exe pack MQTTnet.Extensions.WebSocket4Net.nuspec -Verbosity detailed -Symbols -OutputDir "NuGet" -Version $nugetVersion +.\nuget.exe pack MQTTnet.nuspec -Verbosity detailed -Symbols -SymbolPackageFormat snupkg -OutputDir "NuGet" -Version $nugetVersion +.\nuget.exe pack MQTTnet.NETStandard.nuspec -Verbosity detailed -Symbols -SymbolPackageFormat snupkg -OutputDir "NuGet" -Version $nugetVersion +.\nuget.exe pack MQTTnet.AspNetCore.nuspec -Verbosity detailed -Symbols -SymbolPackageFormat snupkg -OutputDir "NuGet" -Version $nugetVersion +.\nuget.exe pack MQTTnet.Extensions.Rpc.nuspec -Verbosity detailed -Symbols -SymbolPackageFormat snupkg -OutputDir "NuGet" -Version $nugetVersion +.\nuget.exe pack MQTTnet.Extensions.ManagedClient.nuspec -Verbosity detailed -Symbols -SymbolPackageFormat snupkg -OutputDir "NuGet" -Version $nugetVersion +.\nuget.exe pack MQTTnet.Extensions.WebSocket4Net.nuspec -Verbosity detailed -Symbols -SymbolPackageFormat snupkg -OutputDir "NuGet" -Version $nugetVersion Move-Item MQTTnet.AspNetCore.nuspec.old -Destination MQTTnet.AspNetCore.nuspec -Force Move-Item MQTTnet.Extensions.Rpc.nuspec.old -Destination MQTTnet.Extensions.Rpc.nuspec -Force diff --git a/Build/upload.ps1 b/Build/upload.ps1 index adb7c6d..794fe87 100644 --- a/Build/upload.ps1 +++ b/Build/upload.ps1 @@ -7,7 +7,7 @@ foreach ($file in $files) { Write-Host "Uploading: " $file - .\nuget.exe push $file.Fullname $apiKey -NoSymbols -Source https://api.nuget.org/v3/index.json + .\nuget.exe push $file.Fullname $apiKey -Source https://api.nuget.org/v3/index.json } Remove-Item "nuget.exe" -Force -Recurse -ErrorAction SilentlyContinue \ No newline at end of file diff --git a/Source/MQTTnet.AspnetCore/MqttConnectionContext.cs b/Source/MQTTnet.AspnetCore/MqttConnectionContext.cs index f50d85d..2aa9842 100644 --- a/Source/MQTTnet.AspnetCore/MqttConnectionContext.cs +++ b/Source/MQTTnet.AspnetCore/MqttConnectionContext.cs @@ -162,11 +162,12 @@ namespace MQTTnet.AspNetCore var buffer = formatter.Encode(packet); var msg = buffer.AsMemory(); var output = _output; - msg.CopyTo(output.GetMemory(msg.Length)); - BytesSent += msg.Length; + var result = await output.WriteAsync(msg, cancellationToken).ConfigureAwait(false); + if (result.IsCompleted) + { + BytesSent += msg.Length; + } PacketFormatterAdapter.FreeBuffer(); - output.Advance(msg.Length); - await output.FlushAsync().ConfigureAwait(false); } finally { diff --git a/Source/MQTTnet.Extensions.ManagedClient/ManagedMqttClient.cs b/Source/MQTTnet.Extensions.ManagedClient/ManagedMqttClient.cs index 74d148d..f53a8dd 100644 --- a/Source/MQTTnet.Extensions.ManagedClient/ManagedMqttClient.cs +++ b/Source/MQTTnet.Extensions.ManagedClient/ManagedMqttClient.cs @@ -24,6 +24,8 @@ namespace MQTTnet.Extensions.ManagedClient private readonly IMqttClient _mqttClient; private readonly IMqttNetChildLogger _logger; + + private readonly AsyncLock _messageQueueLock = new AsyncLock(); private CancellationTokenSource _connectionCancellationToken; private CancellationTokenSource _publishingCancellationToken; @@ -147,7 +149,7 @@ namespace MQTTnet.Extensions.ManagedClient try { - lock (_messageQueue) + using (await _messageQueueLock.WaitAsync(CancellationToken.None).ConfigureAwait(false)) { if (_messageQueue.Count >= Options.MaxPendingMessages) { @@ -167,6 +169,16 @@ namespace MQTTnet.Extensions.ManagedClient } _messageQueue.Enqueue(applicationMessage); + + if (_storageManager != null) + { + if (removedMessage != null) + { + await _storageManager.RemoveAsync(removedMessage).ConfigureAwait(false); + } + + await _storageManager.AddAsync(applicationMessage).ConfigureAwait(false); + } } } finally @@ -181,16 +193,6 @@ namespace MQTTnet.Extensions.ManagedClient } } - - if (_storageManager != null) - { - if (removedMessage != null) - { - await _storageManager.RemoveAsync(removedMessage).ConfigureAwait(false); - } - - await _storageManager.AddAsync(applicationMessage).ConfigureAwait(false); - } } public Task SubscribeAsync(IEnumerable topicFilters) @@ -362,7 +364,7 @@ namespace MQTTnet.Extensions.ManagedClient } catch (Exception exception) { - _logger.Error(exception, "Unhandled exception while publishing queued application messages."); + _logger.Error(exception, "Error while publishing queued application messages."); } finally { @@ -377,7 +379,7 @@ namespace MQTTnet.Extensions.ManagedClient { await _mqttClient.PublishAsync(message.ApplicationMessage).ConfigureAwait(false); - lock (_messageQueue) //lock to avoid conflict with this.PublishAsync + using (await _messageQueueLock.WaitAsync(CancellationToken.None).ConfigureAwait(false)) //lock to avoid conflict with this.PublishAsync { // While publishing this message, this.PublishAsync could have booted this // message off the queue to make room for another (when using a cap @@ -386,11 +388,11 @@ namespace MQTTnet.Extensions.ManagedClient // it from the queue. If not, that means this.PublishAsync has already // removed it, in which case we don't want to do anything. _messageQueue.RemoveFirst(i => i.Id.Equals(message.Id)); - } - - if (_storageManager != null) - { - await _storageManager.RemoveAsync(message).ConfigureAwait(false); + + if (_storageManager != null) + { + await _storageManager.RemoveAsync(message).ConfigureAwait(false); + } } } catch (MqttCommunicationException exception) @@ -408,21 +410,21 @@ namespace MQTTnet.Extensions.ManagedClient //contradict the expected behavior of QoS 1 and 2, that's also true //for the usage of a message queue cap, so it's still consistent //with prior behavior in that way. - lock (_messageQueue) //lock to avoid conflict with this.PublishAsync + using (await _messageQueueLock.WaitAsync(CancellationToken.None).ConfigureAwait(false)) //lock to avoid conflict with this.PublishAsync { _messageQueue.RemoveFirst(i => i.Id.Equals(message.Id)); - } - - if (_storageManager != null) - { - await _storageManager.RemoveAsync(message).ConfigureAwait(false); + + if (_storageManager != null) + { + await _storageManager.RemoveAsync(message).ConfigureAwait(false); + } } } } catch (Exception exception) { transmitException = exception; - _logger.Error(exception, $"Unhandled exception while publishing application message ({message.Id})."); + _logger.Error(exception, $"Error while publishing application message ({message.Id})."); } finally { @@ -533,4 +535,4 @@ namespace MQTTnet.Extensions.ManagedClient _connectionCancellationToken = null; } } -} \ No newline at end of file +} diff --git a/Source/MQTTnet.Server/Controllers/ClientsController.cs b/Source/MQTTnet.Server/Controllers/ClientsController.cs index 6898375..bd9795a 100644 --- a/Source/MQTTnet.Server/Controllers/ClientsController.cs +++ b/Source/MQTTnet.Server/Controllers/ClientsController.cs @@ -13,7 +13,7 @@ namespace MQTTnet.Server.Controllers { [Authorize] [ApiController] - public class ClientsController : ControllerBase + public class ClientsController : Controller { private readonly MqttServerService _mqttServerService; diff --git a/Source/MQTTnet.Server/Controllers/MessagesController.cs b/Source/MQTTnet.Server/Controllers/MessagesController.cs new file mode 100644 index 0000000..6bd00e7 --- /dev/null +++ b/Source/MQTTnet.Server/Controllers/MessagesController.cs @@ -0,0 +1,51 @@ +using System; +using System.IO; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Authorization; +using Microsoft.AspNetCore.Mvc; +using MQTTnet.Protocol; +using MQTTnet.Server.Mqtt; + +namespace MQTTnet.Server.Controllers +{ + [Authorize] + [ApiController] + public class MessagesController : Controller + { + private readonly MqttServerService _mqttServerService; + + public MessagesController(MqttServerService mqttServerService) + { + _mqttServerService = mqttServerService ?? throw new ArgumentNullException(nameof(mqttServerService)); + } + + [Route("api/v1/messages")] + [HttpPost] + public async Task PostMessage(MqttApplicationMessage message) + { + await _mqttServerService.PublishAsync(message); + return Ok(); + } + + [Route("api/v1/messages/{*topic}")] + [HttpPost] + public async Task PostMessage(string topic, int qosLevel = 0) + { + byte[] payload; + + using (var memoryStream = new MemoryStream()) + { + await HttpContext.Request.Body.CopyToAsync(memoryStream); + payload = memoryStream.ToArray(); + } + + var message = new MqttApplicationMessageBuilder() + .WithTopic(topic) + .WithPayload(payload) + .WithQualityOfServiceLevel((MqttQualityOfServiceLevel)qosLevel) + .Build(); + + return await PostMessage(message); + } + } +} diff --git a/Source/MQTTnet.Server/Controllers/RetainedApplicationMessagesController.cs b/Source/MQTTnet.Server/Controllers/RetainedApplicationMessagesController.cs index 9c9f273..030d141 100644 --- a/Source/MQTTnet.Server/Controllers/RetainedApplicationMessagesController.cs +++ b/Source/MQTTnet.Server/Controllers/RetainedApplicationMessagesController.cs @@ -12,7 +12,7 @@ namespace MQTTnet.Server.Controllers { [Authorize] [ApiController] - public class RetainedApplicationMessagesController : ControllerBase + public class RetainedApplicationMessagesController : Controller { private readonly MqttServerService _mqttServerService; diff --git a/Source/MQTTnet.Server/Controllers/ServerController.cs b/Source/MQTTnet.Server/Controllers/ServerController.cs new file mode 100644 index 0000000..cf53bba --- /dev/null +++ b/Source/MQTTnet.Server/Controllers/ServerController.cs @@ -0,0 +1,18 @@ +using Microsoft.AspNetCore.Authorization; +using Microsoft.AspNetCore.Mvc; +using System.Reflection; + +namespace MQTTnet.Server.Controllers +{ + [Authorize] + [ApiController] + public class ServerController : Controller + { + [Route("api/v1/server/version")] + [HttpGet] + public ActionResult GetVersion() + { + return Assembly.GetExecutingAssembly().GetCustomAttribute().InformationalVersion; + } + } +} diff --git a/Source/MQTTnet.Server/Controllers/SessionsController.cs b/Source/MQTTnet.Server/Controllers/SessionsController.cs index 463c004..5fd0638 100644 --- a/Source/MQTTnet.Server/Controllers/SessionsController.cs +++ b/Source/MQTTnet.Server/Controllers/SessionsController.cs @@ -13,7 +13,7 @@ namespace MQTTnet.Server.Controllers { [Authorize] [ApiController] - public class SessionsController : ControllerBase + public class SessionsController : Controller { private readonly MqttServerService _mqttServerService; diff --git a/Source/MQTTnet.Server/Mqtt/MqttApplicationMessageInterceptor.cs b/Source/MQTTnet.Server/Mqtt/MqttApplicationMessageInterceptor.cs index c5f3afd..00eb0e7 100644 --- a/Source/MQTTnet.Server/Mqtt/MqttApplicationMessageInterceptor.cs +++ b/Source/MQTTnet.Server/Mqtt/MqttApplicationMessageInterceptor.cs @@ -22,14 +22,17 @@ namespace MQTTnet.Server.Mqtt { try { + var sessionItems = (PythonDictionary)context.SessionItems[MqttServerConnectionValidator.WrappedSessionItemsKey]; + var pythonContext = new PythonDictionary { + { "client_id", context.ClientId }, + { "session_items", sessionItems }, + { "retain", context.ApplicationMessage.Retain }, { "accept_publish", context.AcceptPublish }, { "close_connection", context.CloseConnection }, - { "client_id", context.ClientId }, { "topic", context.ApplicationMessage.Topic }, - { "qos", (int)context.ApplicationMessage.QualityOfServiceLevel }, - { "retain", context.ApplicationMessage.Retain } + { "qos", (int)context.ApplicationMessage.QualityOfServiceLevel } }; _pythonScriptHostService.InvokeOptionalFunction("on_intercept_application_message", pythonContext); diff --git a/Source/MQTTnet.Server/Mqtt/MqttServerConnectionValidator.cs b/Source/MQTTnet.Server/Mqtt/MqttServerConnectionValidator.cs index 3b1a2fc..d002842 100644 --- a/Source/MQTTnet.Server/Mqtt/MqttServerConnectionValidator.cs +++ b/Source/MQTTnet.Server/Mqtt/MqttServerConnectionValidator.cs @@ -9,6 +9,8 @@ namespace MQTTnet.Server.Mqtt { public class MqttServerConnectionValidator : IMqttServerConnectionValidator { + public const string WrappedSessionItemsKey = "WRAPPED_ITEMS"; + private readonly PythonScriptHostService _pythonScriptHostService; private readonly ILogger _logger; @@ -22,6 +24,8 @@ namespace MQTTnet.Server.Mqtt { try { + var sessionItems = new PythonDictionary(); + var pythonContext = new PythonDictionary { { "endpoint", context.Endpoint }, @@ -33,6 +37,7 @@ namespace MQTTnet.Server.Mqtt { "clean_session", context.CleanSession}, { "authentication_method", context.AuthenticationMethod}, { "authentication_data", new Bytes(context.AuthenticationData ?? new byte[0]) }, + { "session_items", sessionItems }, { "result", PythonConvert.Pythonfy(context.ReasonCode) } }; @@ -40,6 +45,8 @@ namespace MQTTnet.Server.Mqtt _pythonScriptHostService.InvokeOptionalFunction("on_validate_client_connection", pythonContext); context.ReasonCode = PythonConvert.ParseEnum((string)pythonContext["result"]); + + context.SessionItems[WrappedSessionItemsKey] = sessionItems; } catch (Exception exception) { diff --git a/Source/MQTTnet.Server/Mqtt/MqttSubscriptionInterceptor.cs b/Source/MQTTnet.Server/Mqtt/MqttSubscriptionInterceptor.cs index 2d37f74..ba99e9f 100644 --- a/Source/MQTTnet.Server/Mqtt/MqttSubscriptionInterceptor.cs +++ b/Source/MQTTnet.Server/Mqtt/MqttSubscriptionInterceptor.cs @@ -21,14 +21,16 @@ namespace MQTTnet.Server.Mqtt { try { + var sessionItems = (PythonDictionary)context.SessionItems[MqttServerConnectionValidator.WrappedSessionItemsKey]; + var pythonContext = new PythonDictionary { - { "accept_subscription", context.AcceptSubscription }, - { "close_connection", context.CloseConnection }, - { "client_id", context.ClientId }, + { "session_items", sessionItems }, { "topic", context.TopicFilter.Topic }, - { "qos", (int)context.TopicFilter.QualityOfServiceLevel } + { "qos", (int)context.TopicFilter.QualityOfServiceLevel }, + { "accept_subscription", context.AcceptSubscription }, + { "close_connection", context.CloseConnection } }; _pythonScriptHostService.InvokeOptionalFunction("on_intercept_subscription", pythonContext); diff --git a/Source/MQTTnet/Client/MqttClient.cs b/Source/MQTTnet/Client/MqttClient.cs index e29e9d0..27b56ff 100644 --- a/Source/MQTTnet/Client/MqttClient.cs +++ b/Source/MQTTnet/Client/MqttClient.cs @@ -257,12 +257,12 @@ namespace MQTTnet.Client { var clientWasConnected = IsConnected; - InitiateDisconnect(); - - IsConnected = false; + TryInitiateDisconnect(); try { + IsConnected = false; + if (_adapter != null) { _logger.Verbose("Disconnecting [Timeout={0}]", Options.CommunicationTimeout); @@ -295,7 +295,7 @@ namespace MQTTnet.Client } } - private void InitiateDisconnect() + private void TryInitiateDisconnect() { lock (_disconnectLock) { diff --git a/Source/MQTTnet/Client/Options/MqttClientOptionsBuilder.cs b/Source/MQTTnet/Client/Options/MqttClientOptionsBuilder.cs index a7aefd1..65a1ec9 100644 --- a/Source/MQTTnet/Client/Options/MqttClientOptionsBuilder.cs +++ b/Source/MQTTnet/Client/Options/MqttClientOptionsBuilder.cs @@ -139,6 +139,13 @@ namespace MQTTnet.Client.Options return this; } + public MqttClientOptionsBuilder WithCredentials(IMqttClientCredentials credentials) + { + _options.Credentials = credentials; + + return this; + } + public MqttClientOptionsBuilder WithExtendedAuthenticationExchangeHandler(IMqttExtendedAuthenticationExchangeHandler handler) { _options.ExtendedAuthenticationExchangeHandler = handler; diff --git a/Source/MQTTnet/Implementations/MqttTcpChannel.cs b/Source/MQTTnet/Implementations/MqttTcpChannel.cs index 63adf55..12fd2bb 100644 --- a/Source/MQTTnet/Implementations/MqttTcpChannel.cs +++ b/Source/MQTTnet/Implementations/MqttTcpChannel.cs @@ -84,9 +84,9 @@ namespace MQTTnet.Implementations if (_options.TlsOptions.UseTls) { var sslStream = new SslStream(networkStream, false, InternalUserCertificateValidationCallback); - await sslStream.AuthenticateAsClientAsync(_options.Server, LoadCertificates(), _options.TlsOptions.SslProtocol, _options.TlsOptions.IgnoreCertificateRevocationErrors).ConfigureAwait(false); - _stream = sslStream; + + await sslStream.AuthenticateAsClientAsync(_options.Server, LoadCertificates(), _options.TlsOptions.SslProtocol, _options.TlsOptions.IgnoreCertificateRevocationErrors).ConfigureAwait(false); } else { diff --git a/Source/MQTTnet/Implementations/MqttTcpServerAdapter.cs b/Source/MQTTnet/Implementations/MqttTcpServerAdapter.cs index 0e28ad0..e3dcab8 100644 --- a/Source/MQTTnet/Implementations/MqttTcpServerAdapter.cs +++ b/Source/MQTTnet/Implementations/MqttTcpServerAdapter.cs @@ -48,7 +48,7 @@ namespace MQTTnet.Implementations throw new ArgumentException("TLS certificate is not set."); } - var tlsCertificate = new X509Certificate2(options.TlsEndpointOptions.Certificate); + var tlsCertificate = new X509Certificate2(options.TlsEndpointOptions.Certificate, options.TlsEndpointOptions.CertificateCredentials.Password); if (!tlsCertificate.HasPrivateKey) { throw new InvalidOperationException("The certificate for TLS encryption must contain the private key."); diff --git a/Source/MQTTnet/Implementations/MqttTcpServerListener.cs b/Source/MQTTnet/Implementations/MqttTcpServerListener.cs index 62eea00..d57888e 100644 --- a/Source/MQTTnet/Implementations/MqttTcpServerListener.cs +++ b/Source/MQTTnet/Implementations/MqttTcpServerListener.cs @@ -61,6 +61,8 @@ namespace MQTTnet.Implementations _socket = new Socket(_addressFamily, SocketType.Stream, ProtocolType.Tcp); + // Usage of socket options is described here: https://docs.microsoft.com/en-us/dotnet/api/system.net.sockets.socket.setsocketoption?view=netcore-2.2 + if (_options.ReuseAddress) { _socket.SetSocketOption(SocketOptionLevel.Socket, SocketOptionName.ReuseAddress, true); @@ -68,7 +70,7 @@ namespace MQTTnet.Implementations if (_options.NoDelay) { - _socket.SetSocketOption(SocketOptionLevel.Socket, SocketOptionName.NoDelay, true); + _socket.SetSocketOption(SocketOptionLevel.Tcp, SocketOptionName.NoDelay, true); } _socket.Bind(_localEndPoint); @@ -160,7 +162,7 @@ namespace MQTTnet.Implementations if (_tlsCertificate != null) { - var sslStream = new SslStream(stream, false); + var sslStream = new SslStream(stream, false, _tlsOptions.RemoteCertificateValidationCallback); await sslStream.AuthenticateAsServerAsync( _tlsCertificate, @@ -171,6 +173,11 @@ namespace MQTTnet.Implementations stream = sslStream; clientCertificate = sslStream.RemoteCertificate as X509Certificate2; + + if (clientCertificate == null && sslStream.RemoteCertificate != null) + { + clientCertificate = new X509Certificate2(sslStream.RemoteCertificate.Export(X509ContentType.Cert)); + } } var clientHandler = ClientHandler; diff --git a/Source/MQTTnet/Internal/AsyncLock.cs b/Source/MQTTnet/Internal/AsyncLock.cs index 9b7eefd..17d7404 100644 --- a/Source/MQTTnet/Internal/AsyncLock.cs +++ b/Source/MQTTnet/Internal/AsyncLock.cs @@ -23,7 +23,7 @@ namespace MQTTnet.Internal public Task WaitAsync(CancellationToken cancellationToken) { var task = _semaphore.WaitAsync(cancellationToken); - if (task.IsCompleted) + if (task.Status == TaskStatus.RanToCompletion) { return _releaser; } diff --git a/Source/MQTTnet/MQTTnet.csproj b/Source/MQTTnet/MQTTnet.csproj index 5e0251d..61ca517 100644 --- a/Source/MQTTnet/MQTTnet.csproj +++ b/Source/MQTTnet/MQTTnet.csproj @@ -41,13 +41,13 @@ RELEASE;NETSTANDARD1_3 - + - + diff --git a/Source/MQTTnet/Server/IMqttServerCredentials.cs b/Source/MQTTnet/Server/IMqttServerCredentials.cs new file mode 100644 index 0000000..5e75be9 --- /dev/null +++ b/Source/MQTTnet/Server/IMqttServerCredentials.cs @@ -0,0 +1,6 @@ +using System; + +public interface IMqttServerCredentials +{ + String Password { get; } +} diff --git a/Source/MQTTnet/Server/IMqttServerOptions.cs b/Source/MQTTnet/Server/IMqttServerOptions.cs index 3a24289..7c5fde4 100644 --- a/Source/MQTTnet/Server/IMqttServerOptions.cs +++ b/Source/MQTTnet/Server/IMqttServerOptions.cs @@ -21,6 +21,8 @@ namespace MQTTnet.Server MqttServerTcpEndpointOptions DefaultEndpointOptions { get; } MqttServerTlsTcpEndpointOptions TlsEndpointOptions { get; } - IMqttServerStorage Storage { get; } + IMqttServerStorage Storage { get; } + + } } \ No newline at end of file diff --git a/Source/MQTTnet/Server/IMqttServerPersistedSession.cs b/Source/MQTTnet/Server/IMqttServerPersistedSession.cs new file mode 100644 index 0000000..18f7165 --- /dev/null +++ b/Source/MQTTnet/Server/IMqttServerPersistedSession.cs @@ -0,0 +1,22 @@ +using System; +using System.Collections.Generic; + +namespace MQTTnet.Server +{ + public interface IMqttServerPersistedSession + { + string ClientId { get; } + + IDictionary Items { get; } + + IList Subscriptions { get; } + + MqttApplicationMessage WillMessage { get; } + + uint? WillDelayInterval { get; } + + DateTime? SessionExpiryTimestamp { get; } + + IList PendingApplicationMessages { get; } + } +} diff --git a/Source/MQTTnet/Server/IMqttServerPersistedSessionsStorage.cs b/Source/MQTTnet/Server/IMqttServerPersistedSessionsStorage.cs new file mode 100644 index 0000000..fa2a4cb --- /dev/null +++ b/Source/MQTTnet/Server/IMqttServerPersistedSessionsStorage.cs @@ -0,0 +1,11 @@ +using System; +using System.Collections.Generic; +using System.Threading.Tasks; + +namespace MQTTnet.Server +{ + public interface IMqttServerPersistedSessionsStorage + { + Task> LoadPersistedSessionsAsync(); + } +} diff --git a/Source/MQTTnet/Server/MqttApplicationMessageInterceptorContext.cs b/Source/MQTTnet/Server/MqttApplicationMessageInterceptorContext.cs index 5612601..11efa57 100644 --- a/Source/MQTTnet/Server/MqttApplicationMessageInterceptorContext.cs +++ b/Source/MQTTnet/Server/MqttApplicationMessageInterceptorContext.cs @@ -1,17 +1,25 @@ -namespace MQTTnet.Server +using System.Collections.Generic; + +namespace MQTTnet.Server { public class MqttApplicationMessageInterceptorContext { - public MqttApplicationMessageInterceptorContext(string clientId, MqttApplicationMessage applicationMessage) + public MqttApplicationMessageInterceptorContext(string clientId, IDictionary sessionItems, MqttApplicationMessage applicationMessage) { ClientId = clientId; ApplicationMessage = applicationMessage; + SessionItems = sessionItems; } public string ClientId { get; } public MqttApplicationMessage ApplicationMessage { get; set; } + /// + /// Gets or sets a key/value collection that can be used to share data within the scope of this session. + /// + public IDictionary SessionItems { get; } + public bool AcceptPublish { get; set; } = true; public bool CloseConnection { get; set; } diff --git a/Source/MQTTnet/Server/MqttClientConnection.cs b/Source/MQTTnet/Server/MqttClientConnection.cs index ed378bb..e71d1a8 100644 --- a/Source/MQTTnet/Server/MqttClientConnection.cs +++ b/Source/MQTTnet/Server/MqttClientConnection.cs @@ -31,7 +31,6 @@ namespace MQTTnet.Server private readonly IMqttChannelAdapter _channelAdapter; private readonly IMqttDataConverter _dataConverter; private readonly string _endpoint; - private readonly MqttConnectPacket _connectPacket; private readonly DateTime _connectedTimestamp; private Task _packageReceiverTask; @@ -60,22 +59,24 @@ namespace MQTTnet.Server _channelAdapter = channelAdapter ?? throw new ArgumentNullException(nameof(channelAdapter)); _dataConverter = _channelAdapter.PacketFormatterAdapter.DataConverter; _endpoint = _channelAdapter.Endpoint; - _connectPacket = connectPacket ?? throw new ArgumentNullException(nameof(connectPacket)); + ConnectPacket = connectPacket ?? throw new ArgumentNullException(nameof(connectPacket)); if (logger == null) throw new ArgumentNullException(nameof(logger)); _logger = logger.CreateChildLogger(nameof(MqttClientConnection)); - _keepAliveMonitor = new MqttClientKeepAliveMonitor(_connectPacket.ClientId, StopAsync, _logger); + _keepAliveMonitor = new MqttClientKeepAliveMonitor(ConnectPacket.ClientId, StopAsync, _logger); _connectedTimestamp = DateTime.UtcNow; _lastPacketReceivedTimestamp = _connectedTimestamp; _lastNonKeepAlivePacketReceivedTimestamp = _lastPacketReceivedTimestamp; } - public string ClientId => _connectPacket.ClientId; + public MqttConnectPacket ConnectPacket { get; } - public MqttClientSession Session { get; } + public string ClientId => ConnectPacket.ClientId; + public MqttClientSession Session { get; } + public async Task StopAsync() { StopInternal(); @@ -133,12 +134,12 @@ namespace MQTTnet.Server _channelAdapter.ReadingPacketStartedCallback = OnAdapterReadingPacketStarted; _channelAdapter.ReadingPacketCompletedCallback = OnAdapterReadingPacketCompleted; - Session.WillMessage = _connectPacket.WillMessage; + Session.WillMessage = ConnectPacket.WillMessage; Task.Run(() => SendPendingPacketsAsync(_cancellationToken.Token), _cancellationToken.Token).Forget(_logger); // TODO: Change to single thread in SessionManager. Or use SessionManager and stats from KeepAliveMonitor. - _keepAliveMonitor.Start(_connectPacket.KeepAlivePeriod, _cancellationToken.Token); + _keepAliveMonitor.Start(ConnectPacket.KeepAlivePeriod, _cancellationToken.Token); await SendAsync( new MqttConnAckPacket @@ -228,7 +229,7 @@ namespace MQTTnet.Server } else { - _logger.Error(exception, "Client '{0}': Unhandled exception while receiving client packets.", ClientId); + _logger.Error(exception, "Client '{0}': Error while receiving client packets.", ClientId); } StopInternal(); @@ -271,7 +272,7 @@ namespace MQTTnet.Server private async Task HandleIncomingSubscribePacketAsync(MqttSubscribePacket subscribePacket) { // TODO: Let the channel adapter create the packet. - var subscribeResult = await Session.SubscriptionsManager.SubscribeAsync(subscribePacket).ConfigureAwait(false); + var subscribeResult = await Session.SubscriptionsManager.SubscribeAsync(subscribePacket, ConnectPacket).ConfigureAwait(false); await SendAsync(subscribeResult.ResponsePacket).ConfigureAwait(false); diff --git a/Source/MQTTnet/Server/MqttClientSession.cs b/Source/MQTTnet/Server/MqttClientSession.cs index 804a223..d165001 100644 --- a/Source/MQTTnet/Server/MqttClientSession.cs +++ b/Source/MQTTnet/Server/MqttClientSession.cs @@ -12,11 +12,12 @@ namespace MQTTnet.Server private readonly DateTime _createdTimestamp = DateTime.UtcNow; - public MqttClientSession(string clientId, MqttServerEventDispatcher eventDispatcher, IMqttServerOptions serverOptions, IMqttNetChildLogger logger) + public MqttClientSession(string clientId, IDictionary items, MqttServerEventDispatcher eventDispatcher, IMqttServerOptions serverOptions, IMqttNetChildLogger logger) { ClientId = clientId ?? throw new ArgumentNullException(nameof(clientId)); + Items = items ?? throw new ArgumentNullException(nameof(items)); - SubscriptionsManager = new MqttClientSubscriptionsManager(clientId, eventDispatcher, serverOptions); + SubscriptionsManager = new MqttClientSubscriptionsManager(this, eventDispatcher, serverOptions); ApplicationMessagesQueue = new MqttClientSessionApplicationMessagesQueue(serverOptions); if (logger == null) throw new ArgumentNullException(nameof(logger)); @@ -33,6 +34,11 @@ namespace MQTTnet.Server public MqttClientSessionApplicationMessagesQueue ApplicationMessagesQueue { get; } + /// + /// Gets or sets a key/value collection that can be used to share data within the scope of this session. + /// + public IDictionary Items { get; } + public void EnqueueApplicationMessage(MqttApplicationMessage applicationMessage, string senderClientId, bool isRetainedApplicationMessage) { var checkSubscriptionsResult = SubscriptionsManager.CheckSubscriptions(applicationMessage.Topic, applicationMessage.QualityOfServiceLevel); diff --git a/Source/MQTTnet/Server/MqttClientSessionsManager.cs b/Source/MQTTnet/Server/MqttClientSessionsManager.cs index c2d5637..db70e95 100644 --- a/Source/MQTTnet/Server/MqttClientSessionsManager.cs +++ b/Source/MQTTnet/Server/MqttClientSessionsManager.cs @@ -20,7 +20,8 @@ namespace MQTTnet.Server private readonly SemaphoreSlim _createConnectionGate = new SemaphoreSlim(1, 1); private readonly ConcurrentDictionary _connections = new ConcurrentDictionary(); private readonly ConcurrentDictionary _sessions = new ConcurrentDictionary(); - + private readonly IDictionary _serverSessionItems = new ConcurrentDictionary(); + private readonly CancellationToken _cancellationToken; private readonly MqttServerEventDispatcher _eventDispatcher; @@ -241,19 +242,19 @@ namespace MQTTnet.Server clientId = connectPacket.ClientId; - var validatorContext = await ValidateConnectionAsync(connectPacket, channelAdapter).ConfigureAwait(false); + var connectionValidatorContext = await ValidateConnectionAsync(connectPacket, channelAdapter).ConfigureAwait(false); - if (validatorContext.ReasonCode != MqttConnectReasonCode.Success) + if (connectionValidatorContext.ReasonCode != MqttConnectReasonCode.Success) { // Send failure response here without preparing a session. The result for a successful connect // will be sent from the session itself. - var connAckPacket = channelAdapter.PacketFormatterAdapter.DataConverter.CreateConnAckPacket(validatorContext); + var connAckPacket = channelAdapter.PacketFormatterAdapter.DataConverter.CreateConnAckPacket(connectionValidatorContext); await channelAdapter.SendPacketAsync(connAckPacket, _options.DefaultCommunicationTimeout, cancellationToken).ConfigureAwait(false); return; } - var connection = await CreateConnectionAsync(channelAdapter, connectPacket).ConfigureAwait(false); + var connection = await CreateConnectionAsync(connectPacket, connectionValidatorContext, channelAdapter).ConfigureAwait(false); await _eventDispatcher.HandleClientConnectedAsync(clientId).ConfigureAwait(false); @@ -289,7 +290,7 @@ namespace MQTTnet.Server private async Task ValidateConnectionAsync(MqttConnectPacket connectPacket, IMqttChannelAdapter channelAdapter) { - var context = new MqttConnectionValidatorContext(connectPacket, channelAdapter); + var context = new MqttConnectionValidatorContext(connectPacket, channelAdapter, new ConcurrentDictionary()); var connectionValidator = _options.ConnectionValidator; @@ -302,8 +303,7 @@ namespace MQTTnet.Server await connectionValidator.ValidateConnectionAsync(context).ConfigureAwait(false); // Check the client ID and set a random one if supported. - if (string.IsNullOrEmpty(connectPacket.ClientId) && - channelAdapter.PacketFormatterAdapter.ProtocolVersion == MqttProtocolVersion.V500) + if (string.IsNullOrEmpty(connectPacket.ClientId) && channelAdapter.PacketFormatterAdapter.ProtocolVersion == MqttProtocolVersion.V500) { connectPacket.ClientId = context.AssignedClientIdentifier; } @@ -316,7 +316,7 @@ namespace MQTTnet.Server return context; } - private async Task CreateConnectionAsync(IMqttChannelAdapter channelAdapter, MqttConnectPacket connectPacket) + private async Task CreateConnectionAsync(MqttConnectPacket connectPacket, MqttConnectionValidatorContext connectionValidatorContext, IMqttChannelAdapter channelAdapter) { await _createConnectionGate.WaitAsync(_cancellationToken).ConfigureAwait(false); try @@ -345,7 +345,7 @@ namespace MQTTnet.Server if (session == null) { - session = new MqttClientSession(connectPacket.ClientId, _eventDispatcher, _options, _logger); + session = new MqttClientSession(connectPacket.ClientId, connectionValidatorContext.SessionItems, _eventDispatcher, _options, _logger); _logger.Verbose("Created a new session for client '{0}'.", connectPacket.ClientId); } @@ -362,7 +362,7 @@ namespace MQTTnet.Server } } - private async Task InterceptApplicationMessageAsync(MqttClientConnection sender, MqttApplicationMessage applicationMessage) + private async Task InterceptApplicationMessageAsync(MqttClientConnection senderConnection, MqttApplicationMessage applicationMessage) { var interceptor = _options.ApplicationMessageInterceptor; if (interceptor == null) @@ -370,13 +370,22 @@ namespace MQTTnet.Server return null; } - var senderClientId = sender?.ClientId; - if (sender == null) + string senderClientId; + IDictionary sessionItems; + + var messageIsFromServer = senderConnection == null; + if (messageIsFromServer) { senderClientId = _options.ClientId; + sessionItems = _serverSessionItems; + } + else + { + senderClientId = senderConnection.ClientId; + sessionItems = senderConnection.Session.Items; } - var interceptorContext = new MqttApplicationMessageInterceptorContext(senderClientId, applicationMessage); + var interceptorContext = new MqttApplicationMessageInterceptorContext(senderClientId, sessionItems, applicationMessage); await interceptor.InterceptApplicationMessagePublishAsync(interceptorContext).ConfigureAwait(false); return interceptorContext; } diff --git a/Source/MQTTnet/Server/MqttClientSubscriptionsManager.cs b/Source/MQTTnet/Server/MqttClientSubscriptionsManager.cs index e2024a6..c84a018 100644 --- a/Source/MQTTnet/Server/MqttClientSubscriptionsManager.cs +++ b/Source/MQTTnet/Server/MqttClientSubscriptionsManager.cs @@ -9,21 +9,24 @@ namespace MQTTnet.Server { public class MqttClientSubscriptionsManager { - private readonly Dictionary _subscriptions = new Dictionary(); - private readonly IMqttServerOptions _options; + private readonly Dictionary _subscriptions = new Dictionary(); + private readonly MqttClientSession _clientSession; + private readonly IMqttServerOptions _serverOptions; private readonly MqttServerEventDispatcher _eventDispatcher; - private readonly string _clientId; - public MqttClientSubscriptionsManager(string clientId, MqttServerEventDispatcher eventDispatcher, IMqttServerOptions options) + public MqttClientSubscriptionsManager(MqttClientSession clientSession, MqttServerEventDispatcher eventDispatcher, IMqttServerOptions serverOptions) { - _clientId = clientId ?? throw new ArgumentNullException(nameof(clientId)); - _options = options ?? throw new ArgumentNullException(nameof(options)); + _clientSession = clientSession ?? throw new ArgumentNullException(nameof(clientSession)); + + // TODO: Consider removing the server options here and build a new class "ISubscriptionInterceptor" and just pass it. The instance is generated in the root server class upon start. + _serverOptions = serverOptions ?? throw new ArgumentNullException(nameof(serverOptions)); _eventDispatcher = eventDispatcher ?? throw new ArgumentNullException(nameof(eventDispatcher)); } - public async Task SubscribeAsync(MqttSubscribePacket subscribePacket) + public async Task SubscribeAsync(MqttSubscribePacket subscribePacket, MqttConnectPacket connectPacket) { if (subscribePacket == null) throw new ArgumentNullException(nameof(subscribePacket)); + if (connectPacket == null) throw new ArgumentNullException(nameof(connectPacket)); var result = new MqttClientSubscribeResult { @@ -61,10 +64,10 @@ namespace MQTTnet.Server { lock (_subscriptions) { - _subscriptions[finalTopicFilter.Topic] = finalTopicFilter.QualityOfServiceLevel; + _subscriptions[finalTopicFilter.Topic] = finalTopicFilter; } - await _eventDispatcher.HandleClientSubscribedTopicAsync(_clientId, finalTopicFilter).ConfigureAwait(false); + await _eventDispatcher.HandleClientSubscribedTopicAsync(_clientSession.ClientId, finalTopicFilter).ConfigureAwait(false); } } @@ -73,6 +76,8 @@ namespace MQTTnet.Server public async Task SubscribeAsync(IEnumerable topicFilters) { + if (topicFilters == null) throw new ArgumentNullException(nameof(topicFilters)); + foreach (var topicFilter in topicFilters) { var interceptorContext = await InterceptSubscribeAsync(topicFilter).ConfigureAwait(false); @@ -85,10 +90,10 @@ namespace MQTTnet.Server { lock (_subscriptions) { - _subscriptions[topicFilter.Topic] = topicFilter.QualityOfServiceLevel; + _subscriptions[topicFilter.Topic] = topicFilter; } - await _eventDispatcher.HandleClientSubscribedTopicAsync(_clientId, topicFilter).ConfigureAwait(false); + await _eventDispatcher.HandleClientSubscribedTopicAsync(_clientSession.ClientId, topicFilter).ConfigureAwait(false); } } } @@ -119,7 +124,7 @@ namespace MQTTnet.Server foreach (var topicFilter in unsubscribePacket.TopicFilters) { - await _eventDispatcher.HandleClientUnsubscribedTopicAsync(_clientId, topicFilter).ConfigureAwait(false); + await _eventDispatcher.HandleClientUnsubscribedTopicAsync(_clientSession.ClientId, topicFilter).ConfigureAwait(false); } return unsubAckPacket; @@ -153,7 +158,7 @@ namespace MQTTnet.Server continue; } - qosLevels.Add(subscription.Value); + qosLevels.Add(subscription.Value.QualityOfServiceLevel); } } @@ -192,10 +197,10 @@ namespace MQTTnet.Server private async Task InterceptSubscribeAsync(TopicFilter topicFilter) { - var context = new MqttSubscriptionInterceptorContext(_clientId, topicFilter); - if (_options.SubscriptionInterceptor != null) + var context = new MqttSubscriptionInterceptorContext(_clientSession.ClientId, topicFilter, _clientSession.Items); + if (_serverOptions.SubscriptionInterceptor != null) { - await _options.SubscriptionInterceptor.InterceptSubscriptionAsync(context).ConfigureAwait(false); + await _serverOptions.SubscriptionInterceptor.InterceptSubscriptionAsync(context).ConfigureAwait(false); } return context; diff --git a/Source/MQTTnet/Server/MqttConnectionValidatorContext.cs b/Source/MQTTnet/Server/MqttConnectionValidatorContext.cs index 45dba13..9a5b8b9 100644 --- a/Source/MQTTnet/Server/MqttConnectionValidatorContext.cs +++ b/Source/MQTTnet/Server/MqttConnectionValidatorContext.cs @@ -14,53 +14,59 @@ namespace MQTTnet.Server private readonly MqttConnectPacket _connectPacket; private readonly IMqttChannelAdapter _clientAdapter; - public MqttConnectionValidatorContext(MqttConnectPacket connectPacket, IMqttChannelAdapter clientAdapter) + public MqttConnectionValidatorContext(MqttConnectPacket connectPacket, IMqttChannelAdapter clientAdapter, IDictionary sessionItems) { - _connectPacket = connectPacket ?? throw new ArgumentNullException(nameof(connectPacket)); + _connectPacket = connectPacket; _clientAdapter = clientAdapter ?? throw new ArgumentNullException(nameof(clientAdapter)); + SessionItems = sessionItems; } public string ClientId => _connectPacket.ClientId; - public string Username => _connectPacket.Username; + public string Endpoint => _clientAdapter.Endpoint; - public byte[] RawPassword => _connectPacket.Password; + public bool IsSecureConnection => _clientAdapter.IsSecureConnection; - public string Password => Encoding.UTF8.GetString(RawPassword ?? new byte[0]); + public X509Certificate2 ClientCertificate => _clientAdapter.ClientCertificate; - public MqttApplicationMessage WillMessage => _connectPacket.WillMessage; + public MqttProtocolVersion ProtocolVersion => _clientAdapter.PacketFormatterAdapter.ProtocolVersion; - public bool CleanSession => _connectPacket.CleanSession; + public string Username => _connectPacket?.Username; - public ushort KeepAlivePeriod => _connectPacket.KeepAlivePeriod; + public byte[] RawPassword => _connectPacket?.Password; - public List UserProperties => _connectPacket.Properties?.UserProperties; + public string Password => Encoding.UTF8.GetString(RawPassword ?? new byte[0]); - public byte[] AuthenticationData => _connectPacket.Properties?.AuthenticationData; + public MqttApplicationMessage WillMessage => _connectPacket?.WillMessage; - public string AuthenticationMethod => _connectPacket.Properties?.AuthenticationMethod; + public bool? CleanSession => _connectPacket?.CleanSession; - public uint? MaximumPacketSize => _connectPacket.Properties?.MaximumPacketSize; + public ushort? KeepAlivePeriod => _connectPacket?.KeepAlivePeriod; - public ushort? ReceiveMaximum => _connectPacket.Properties?.ReceiveMaximum; + public List UserProperties => _connectPacket?.Properties?.UserProperties; - public ushort? TopicAliasMaximum => _connectPacket.Properties?.TopicAliasMaximum; + public byte[] AuthenticationData => _connectPacket?.Properties?.AuthenticationData; - public bool? RequestProblemInformation => _connectPacket.Properties?.RequestProblemInformation; + public string AuthenticationMethod => _connectPacket?.Properties?.AuthenticationMethod; - public bool? RequestResponseInformation => _connectPacket.Properties?.RequestResponseInformation; + public uint? MaximumPacketSize => _connectPacket?.Properties?.MaximumPacketSize; - public uint? SessionExpiryInterval => _connectPacket.Properties?.SessionExpiryInterval; + public ushort? ReceiveMaximum => _connectPacket?.Properties?.ReceiveMaximum; - public uint? WillDelayInterval => _connectPacket.Properties?.WillDelayInterval; + public ushort? TopicAliasMaximum => _connectPacket?.Properties?.TopicAliasMaximum; - public string Endpoint => _clientAdapter.Endpoint; + public bool? RequestProblemInformation => _connectPacket?.Properties?.RequestProblemInformation; - public bool IsSecureConnection => _clientAdapter.IsSecureConnection; + public bool? RequestResponseInformation => _connectPacket?.Properties?.RequestResponseInformation; - public X509Certificate2 ClientCertificate => _clientAdapter.ClientCertificate; + public uint? SessionExpiryInterval => _connectPacket?.Properties?.SessionExpiryInterval; - public MqttProtocolVersion ProtocolVersion => _clientAdapter.PacketFormatterAdapter.ProtocolVersion; + public uint? WillDelayInterval => _connectPacket?.Properties?.WillDelayInterval; + + /// + /// Gets or sets a key/value collection that can be used to share data within the scope of this session. + /// + public IDictionary SessionItems { get; } /// /// This is used for MQTTv3 only. diff --git a/Source/MQTTnet/Server/MqttServerOptionsBuilder.cs b/Source/MQTTnet/Server/MqttServerOptionsBuilder.cs index 357cb82..5991e7d 100644 --- a/Source/MQTTnet/Server/MqttServerOptionsBuilder.cs +++ b/Source/MQTTnet/Server/MqttServerOptionsBuilder.cs @@ -1,5 +1,6 @@ using System; using System.Net; +using System.Net.Security; using System.Security.Authentication; namespace MQTTnet.Server @@ -81,9 +82,10 @@ namespace MQTTnet.Server return this; } - public MqttServerOptionsBuilder WithEncryptionCertificate(byte[] value) + public MqttServerOptionsBuilder WithEncryptionCertificate(byte[] value, IMqttServerCredentials credentials = null) { _options.TlsEndpointOptions.Certificate = value; + _options.TlsEndpointOptions.CertificateCredentials = credentials; return this; } @@ -93,11 +95,29 @@ namespace MQTTnet.Server return this; } +#if !WINDOWS_UWP + public MqttServerOptionsBuilder WithClientCertificate(RemoteCertificateValidationCallback validationCallback = null, bool checkCertificateRevocation = false) + { + _options.TlsEndpointOptions.ClientCertificateRequired = true; + _options.TlsEndpointOptions.CheckCertificateRevocation = checkCertificateRevocation; + _options.TlsEndpointOptions.RemoteCertificateValidationCallback = validationCallback; + return this; + } +#endif + public MqttServerOptionsBuilder WithoutEncryptedEndpoint() { _options.TlsEndpointOptions.IsEnabled = false; return this; } + +#if !WINDOWS_UWP + public MqttServerOptionsBuilder WithRemoteCertificateValidationCallback(RemoteCertificateValidationCallback value) + { + _options.TlsEndpointOptions.RemoteCertificateValidationCallback = value; + return this; + } +#endif public MqttServerOptionsBuilder WithStorage(IMqttServerStorage value) { diff --git a/Source/MQTTnet/Server/MqttServerTlsTcpEndpointOptions.cs b/Source/MQTTnet/Server/MqttServerTlsTcpEndpointOptions.cs index 212b052..e92d987 100644 --- a/Source/MQTTnet/Server/MqttServerTlsTcpEndpointOptions.cs +++ b/Source/MQTTnet/Server/MqttServerTlsTcpEndpointOptions.cs @@ -1,4 +1,5 @@ -using System.Security.Authentication; +using System.Net.Security; +using System.Security.Authentication; namespace MQTTnet.Server { @@ -11,10 +12,15 @@ namespace MQTTnet.Server public byte[] Certificate { get; set; } + public IMqttServerCredentials CertificateCredentials { get; set; } + public bool ClientCertificateRequired { get; set; } public bool CheckCertificateRevocation { get; set; } - + +#if !WINDOWS_UWP + public RemoteCertificateValidationCallback RemoteCertificateValidationCallback { get; set; } +#endif public SslProtocols SslProtocol { get; set; } = SslProtocols.Tls12; } } diff --git a/Source/MQTTnet/Server/MqttSubscriptionInterceptorContext.cs b/Source/MQTTnet/Server/MqttSubscriptionInterceptorContext.cs index ca98c95..7e3963b 100644 --- a/Source/MQTTnet/Server/MqttSubscriptionInterceptorContext.cs +++ b/Source/MQTTnet/Server/MqttSubscriptionInterceptorContext.cs @@ -1,19 +1,25 @@ -using System; +using System.Collections.Generic; namespace MQTTnet.Server { public class MqttSubscriptionInterceptorContext { - public MqttSubscriptionInterceptorContext(string clientId, TopicFilter topicFilter) + public MqttSubscriptionInterceptorContext(string clientId, TopicFilter topicFilter, IDictionary sessionItems) { ClientId = clientId; - TopicFilter = topicFilter ?? throw new ArgumentNullException(nameof(topicFilter)); + TopicFilter = topicFilter; + SessionItems = sessionItems; } public string ClientId { get; } public TopicFilter TopicFilter { get; set; } - + + /// + /// Gets or sets a key/value collection that can be used to share data within the scope of this session. + /// + public IDictionary SessionItems { get; } + public bool AcceptSubscription { get; set; } = true; public bool CloseConnection { get; set; } diff --git a/Tests/MQTTnet.AspNetCore.Tests/Mockups/DuplexPipeMockup.cs b/Tests/MQTTnet.AspNetCore.Tests/Mockups/DuplexPipeMockup.cs index 1774f18..306749b 100644 --- a/Tests/MQTTnet.AspNetCore.Tests/Mockups/DuplexPipeMockup.cs +++ b/Tests/MQTTnet.AspNetCore.Tests/Mockups/DuplexPipeMockup.cs @@ -4,13 +4,21 @@ namespace MQTTnet.AspNetCore.Tests.Mockups { public class DuplexPipeMockup : IDuplexPipe { + public DuplexPipeMockup() + { + var pool = new LimitedMemoryPool(); + var pipeOptions = new PipeOptions(pool); + Receive = new Pipe(pipeOptions); + Send = new Pipe(pipeOptions); + } + PipeReader IDuplexPipe.Input => Receive.Reader; PipeWriter IDuplexPipe.Output => Send.Writer; - public Pipe Receive { get; set; } = new Pipe(); + public Pipe Receive { get; set; } - public Pipe Send { get; set; } = new Pipe(); + public Pipe Send { get; set; } } } diff --git a/Tests/MQTTnet.AspNetCore.Tests/Mockups/LimitedMemoryPool.cs b/Tests/MQTTnet.AspNetCore.Tests/Mockups/LimitedMemoryPool.cs new file mode 100644 index 0000000..ac5c23c --- /dev/null +++ b/Tests/MQTTnet.AspNetCore.Tests/Mockups/LimitedMemoryPool.cs @@ -0,0 +1,18 @@ +using System.Buffers; + +namespace MQTTnet.AspNetCore.Tests.Mockups +{ + public class LimitedMemoryPool : MemoryPool + { + protected override void Dispose(bool disposing) + { + } + + public override IMemoryOwner Rent(int minBufferSize = -1) + { + return new MemoryOwner(minBufferSize); + } + + public override int MaxBufferSize { get; } + } +} \ No newline at end of file diff --git a/Tests/MQTTnet.AspNetCore.Tests/Mockups/MemoryOwner.cs b/Tests/MQTTnet.AspNetCore.Tests/Mockups/MemoryOwner.cs new file mode 100644 index 0000000..1b7b02f --- /dev/null +++ b/Tests/MQTTnet.AspNetCore.Tests/Mockups/MemoryOwner.cs @@ -0,0 +1,33 @@ +using System; +using System.Buffers; + +namespace MQTTnet.AspNetCore.Tests.Mockups +{ + public class MemoryOwner : IMemoryOwner + { + private readonly byte[] _raw; + + public MemoryOwner(int size) + { + if (size <= 0) + { + size = 1024; + } + + if (size > 4096) + { + size = 4096; + } + + _raw = ArrayPool.Shared.Rent(size); + Memory = _raw; + } + + public void Dispose() + { + ArrayPool.Shared.Return(_raw); + } + + public Memory Memory { get; } + } +} \ No newline at end of file diff --git a/Tests/MQTTnet.AspNetCore.Tests/MqttConnectionContextTest.cs b/Tests/MQTTnet.AspNetCore.Tests/MqttConnectionContextTest.cs index 8fb74ae..f916779 100644 --- a/Tests/MQTTnet.AspNetCore.Tests/MqttConnectionContextTest.cs +++ b/Tests/MQTTnet.AspNetCore.Tests/MqttConnectionContextTest.cs @@ -47,5 +47,21 @@ namespace MQTTnet.AspNetCore.Tests await Task.WhenAll(tasks).ConfigureAwait(false); } + + + [TestMethod] + public async Task TestLargePacket() + { + var serializer = new MqttPacketFormatterAdapter(MqttProtocolVersion.V311); + var pipe = new DuplexPipeMockup(); + var connection = new DefaultConnectionContext(); + connection.Transport = pipe; + var ctx = new MqttConnectionContext(serializer, connection); + + await ctx.SendPacketAsync(new MqttPublishPacket() { Payload = new byte[20_000] }, TimeSpan.Zero, CancellationToken.None).ConfigureAwait(false); + + var readResult = await pipe.Send.Reader.ReadAsync(); + Assert.IsTrue(readResult.Buffer.Length > 20000); + } } } diff --git a/Tests/MQTTnet.Core.Tests/ManagedMqttClient_Tests.cs b/Tests/MQTTnet.Core.Tests/ManagedMqttClient_Tests.cs index ca665a1..0aeea6d 100644 --- a/Tests/MQTTnet.Core.Tests/ManagedMqttClient_Tests.cs +++ b/Tests/MQTTnet.Core.Tests/ManagedMqttClient_Tests.cs @@ -1,4 +1,5 @@ -using System.Threading; +using System.Collections.Generic; +using System.Threading; using System.Threading.Tasks; using Microsoft.VisualStudio.TestTools.UnitTesting; using MQTTnet.Client; @@ -108,5 +109,92 @@ namespace MQTTnet.Tests Assert.AreEqual(0, (await server.GetClientStatusAsync()).Count); } } + + [TestMethod] + public async Task Storage_Queue_Drains() + { + using (var testEnvironment = new TestEnvironment()) + { + testEnvironment.IgnoreClientLogErrors = true; + testEnvironment.IgnoreServerLogErrors = true; + + var factory = new MqttFactory(); + + var server = await testEnvironment.StartServerAsync(); + + var managedClient = new ManagedMqttClient(testEnvironment.CreateClient(), new MqttNetLogger().CreateChildLogger()); + var clientOptions = new MqttClientOptionsBuilder() + .WithTcpServer("localhost", testEnvironment.ServerPort); + var storage = new ManagedMqttClientTestStorage(); + + TaskCompletionSource connected = new TaskCompletionSource(); + managedClient.ConnectedHandler = new MqttClientConnectedHandlerDelegate(e => + { + managedClient.ConnectedHandler = null; + connected.SetResult(true); + }); + + await managedClient.StartAsync(new ManagedMqttClientOptionsBuilder() + .WithClientOptions(clientOptions) + .WithStorage(storage) + .WithAutoReconnectDelay(System.TimeSpan.FromSeconds(5)) + .Build()); + + await connected.Task; + + await testEnvironment.Server.StopAsync(); + + await managedClient.PublishAsync(new MqttApplicationMessage { Topic = "1" }); + + //Message should have been added to the storage queue in PublishAsync, + //and we are awaiting PublishAsync, so the message should already be + //in storage at this point (i.e. no waiting). + Assert.AreEqual(1, storage.GetMessageCount()); + + connected = new TaskCompletionSource(); + managedClient.ConnectedHandler = new MqttClientConnectedHandlerDelegate(e => + { + managedClient.ConnectedHandler = null; + connected.SetResult(true); + }); + + await testEnvironment.Server.StartAsync(new MqttServerOptionsBuilder() + .WithDefaultEndpointPort(testEnvironment.ServerPort).Build()); + + await connected.Task; + + //Wait 500ms here so the client has time to publish the queued message + await Task.Delay(500); + + Assert.AreEqual(0, storage.GetMessageCount()); + + await managedClient.StopAsync(); + } + } + } + + public class ManagedMqttClientTestStorage : IManagedMqttClientStorage + { + private IList _messages = null; + + public Task> LoadQueuedMessagesAsync() + { + if (_messages == null) + { + _messages = new List(); + } + return Task.FromResult(_messages); + } + + public Task SaveQueuedMessagesAsync(IList messages) + { + _messages = messages; + return Task.FromResult(0); + } + + public int GetMessageCount() + { + return _messages.Count; + } } -} \ No newline at end of file +} diff --git a/Tests/MQTTnet.Core.Tests/MqttClient_Tests.cs b/Tests/MQTTnet.Core.Tests/MqttClient_Tests.cs index 155b55e..51d1753 100644 --- a/Tests/MQTTnet.Core.Tests/MqttClient_Tests.cs +++ b/Tests/MQTTnet.Core.Tests/MqttClient_Tests.cs @@ -109,6 +109,83 @@ namespace MQTTnet.Tests } } + [TestMethod] + public async Task Reconnect_While_Server_Offline() + { + using (var testEnvironment = new TestEnvironment()) + { + testEnvironment.IgnoreClientLogErrors = true; + + var server = await testEnvironment.StartServerAsync(); + var client = await testEnvironment.ConnectClientAsync(); + + await Task.Delay(500); + Assert.IsTrue(client.IsConnected); + + await server.StopAsync(); + await Task.Delay(500); + Assert.IsFalse(client.IsConnected); + + for (var i = 0; i < 5; i++) + { + try + { + await client.ConnectAsync(new MqttClientOptionsBuilder().WithTcpServer("127.0.0.1", testEnvironment.ServerPort).Build()); + Assert.Fail("Must fail!"); + } + catch + { + } + } + + await server.StartAsync(new MqttServerOptionsBuilder().WithDefaultEndpointPort(testEnvironment.ServerPort).Build()); + await Task.Delay(500); + + await client.ConnectAsync(new MqttClientOptionsBuilder().WithTcpServer("127.0.0.1", testEnvironment.ServerPort).Build()); + Assert.IsTrue(client.IsConnected); + } + } + + [TestMethod] + public async Task Reconnect_From_Disconnected_Event() + { + using (var testEnvironment = new TestEnvironment()) + { + testEnvironment.IgnoreClientLogErrors = true; + + var client = testEnvironment.CreateClient(); + + var tries = 0; + var maxTries = 3; + + client.UseDisconnectedHandler(async e => + { + if (tries >= maxTries) + { + return; + } + + Interlocked.Increment(ref tries); + + await Task.Delay(100); + await client.ConnectAsync(new MqttClientOptionsBuilder().WithTcpServer("127.0.0.1", testEnvironment.ServerPort).Build()); + }); + + try + { + await client.ConnectAsync(new MqttClientOptionsBuilder().WithTcpServer("127.0.0.1", testEnvironment.ServerPort).Build()); + Assert.Fail("Must fail!"); + } + catch + { + } + + SpinWait.SpinUntil(() => tries >= maxTries, 10000); + + Assert.AreEqual(maxTries, tries); + } + } + [TestMethod] public async Task PacketIdentifier_In_Publish_Result() { @@ -158,6 +235,25 @@ namespace MQTTnet.Tests } } + [TestMethod] + public async Task Fire_Disconnected_Event_On_Server_Shutdown() + { + using (var testEnvironment = new TestEnvironment()) + { + var server = await testEnvironment.StartServerAsync(); + var client = await testEnvironment.ConnectClientAsync(); + + var handlerFired = false; + client.UseDisconnectedHandler(e => handlerFired = true); + + await server.StopAsync(); + + await Task.Delay(4000); + + Assert.IsTrue(handlerFired); + } + } + [TestMethod] public async Task Disconnect_Event_Contains_Exception() { diff --git a/Tests/MQTTnet.Core.Tests/MqttSubscriptionsManager_Tests.cs b/Tests/MQTTnet.Core.Tests/MqttSubscriptionsManager_Tests.cs index 1c4ec84..6f0d542 100644 --- a/Tests/MQTTnet.Core.Tests/MqttSubscriptionsManager_Tests.cs +++ b/Tests/MQTTnet.Core.Tests/MqttSubscriptionsManager_Tests.cs @@ -1,4 +1,6 @@ -using Microsoft.VisualStudio.TestTools.UnitTesting; +using System.Collections.Concurrent; +using System.Threading.Tasks; +using Microsoft.VisualStudio.TestTools.UnitTesting; using MQTTnet.Packets; using MQTTnet.Protocol; using MQTTnet.Server; @@ -10,14 +12,17 @@ namespace MQTTnet.Tests public class MqttSubscriptionsManager_Tests { [TestMethod] - public void MqttSubscriptionsManager_SubscribeSingleSuccess() + public async Task MqttSubscriptionsManager_SubscribeSingleSuccess() { - var sm = new MqttClientSubscriptionsManager("", new MqttServerEventDispatcher(new TestLogger()), new MqttServerOptions()); + var s = new MqttClientSession("", new ConcurrentDictionary(), + new MqttServerEventDispatcher(new TestLogger()), new MqttServerOptions(), new TestLogger()); + + var sm = new MqttClientSubscriptionsManager(s, new MqttServerEventDispatcher(new TestLogger()), new MqttServerOptions()); var sp = new MqttSubscribePacket(); sp.TopicFilters.Add(new TopicFilterBuilder().WithTopic("A/B/C").Build()); - sm.SubscribeAsync(sp).GetAwaiter().GetResult(); + await sm.SubscribeAsync(sp, new MqttConnectPacket()); var result = sm.CheckSubscriptions("A/B/C", MqttQualityOfServiceLevel.AtMostOnce); Assert.IsTrue(result.IsSubscribed); @@ -25,14 +30,17 @@ namespace MQTTnet.Tests } [TestMethod] - public void MqttSubscriptionsManager_SubscribeDifferentQoSSuccess() + public async Task MqttSubscriptionsManager_SubscribeDifferentQoSSuccess() { - var sm = new MqttClientSubscriptionsManager("", new MqttServerEventDispatcher(new TestLogger()), new MqttServerOptions()); + var s = new MqttClientSession("", new ConcurrentDictionary(), + new MqttServerEventDispatcher(new TestLogger()), new MqttServerOptions(), new TestLogger()); + + var sm = new MqttClientSubscriptionsManager(s, new MqttServerEventDispatcher(new TestLogger()), new MqttServerOptions()); var sp = new MqttSubscribePacket(); sp.TopicFilters.Add(new TopicFilter { Topic = "A/B/C", QualityOfServiceLevel = MqttQualityOfServiceLevel.AtMostOnce }); - sm.SubscribeAsync(sp).GetAwaiter().GetResult(); + await sm.SubscribeAsync(sp, new MqttConnectPacket()); var result = sm.CheckSubscriptions("A/B/C", MqttQualityOfServiceLevel.ExactlyOnce); Assert.IsTrue(result.IsSubscribed); @@ -40,15 +48,18 @@ namespace MQTTnet.Tests } [TestMethod] - public void MqttSubscriptionsManager_SubscribeTwoTimesSuccess() + public async Task MqttSubscriptionsManager_SubscribeTwoTimesSuccess() { - var sm = new MqttClientSubscriptionsManager("", new MqttServerEventDispatcher(new TestLogger()), new MqttServerOptions()); + var s = new MqttClientSession("", new ConcurrentDictionary(), + new MqttServerEventDispatcher(new TestLogger()), new MqttServerOptions(), new TestLogger()); + + var sm = new MqttClientSubscriptionsManager(s, new MqttServerEventDispatcher(new TestLogger()), new MqttServerOptions()); var sp = new MqttSubscribePacket(); sp.TopicFilters.Add(new TopicFilter { Topic = "#", QualityOfServiceLevel = MqttQualityOfServiceLevel.AtMostOnce }); sp.TopicFilters.Add(new TopicFilter { Topic = "A/B/C", QualityOfServiceLevel = MqttQualityOfServiceLevel.AtLeastOnce }); - sm.SubscribeAsync(sp).GetAwaiter().GetResult(); + await sm.SubscribeAsync(sp, new MqttConnectPacket()); var result = sm.CheckSubscriptions("A/B/C", MqttQualityOfServiceLevel.ExactlyOnce); Assert.IsTrue(result.IsSubscribed); @@ -56,33 +67,39 @@ namespace MQTTnet.Tests } [TestMethod] - public void MqttSubscriptionsManager_SubscribeSingleNoSuccess() + public async Task MqttSubscriptionsManager_SubscribeSingleNoSuccess() { - var sm = new MqttClientSubscriptionsManager("", new MqttServerEventDispatcher(new TestLogger()), new MqttServerOptions()); + var s = new MqttClientSession("", new ConcurrentDictionary(), + new MqttServerEventDispatcher(new TestLogger()), new MqttServerOptions(), new TestLogger()); + + var sm = new MqttClientSubscriptionsManager(s, new MqttServerEventDispatcher(new TestLogger()), new MqttServerOptions()); var sp = new MqttSubscribePacket(); sp.TopicFilters.Add(new TopicFilterBuilder().WithTopic("A/B/C").Build()); - sm.SubscribeAsync(sp).GetAwaiter().GetResult(); + await sm.SubscribeAsync(sp, new MqttConnectPacket()); Assert.IsFalse(sm.CheckSubscriptions("A/B/X", MqttQualityOfServiceLevel.AtMostOnce).IsSubscribed); } [TestMethod] - public void MqttSubscriptionsManager_SubscribeAndUnsubscribeSingle() + public async Task MqttSubscriptionsManager_SubscribeAndUnsubscribeSingle() { - var sm = new MqttClientSubscriptionsManager("", new MqttServerEventDispatcher(new TestLogger()), new MqttServerOptions()); + var s = new MqttClientSession("", new ConcurrentDictionary(), + new MqttServerEventDispatcher(new TestLogger()), new MqttServerOptions(), new TestLogger()); + + var sm = new MqttClientSubscriptionsManager(s, new MqttServerEventDispatcher(new TestLogger()), new MqttServerOptions()); var sp = new MqttSubscribePacket(); sp.TopicFilters.Add(new TopicFilterBuilder().WithTopic("A/B/C").Build()); - sm.SubscribeAsync(sp).GetAwaiter().GetResult(); + await sm.SubscribeAsync(sp, new MqttConnectPacket()); Assert.IsTrue(sm.CheckSubscriptions("A/B/C", MqttQualityOfServiceLevel.AtMostOnce).IsSubscribed); var up = new MqttUnsubscribePacket(); up.TopicFilters.Add("A/B/C"); - sm.UnsubscribeAsync(up); + await sm.UnsubscribeAsync(up); Assert.IsFalse(sm.CheckSubscriptions("A/B/C", MqttQualityOfServiceLevel.AtMostOnce).IsSubscribed); } diff --git a/Tests/MQTTnet.Core.Tests/RPC_Tests.cs b/Tests/MQTTnet.Core.Tests/RPC_Tests.cs index b0babfa..9f03172 100644 --- a/Tests/MQTTnet.Core.Tests/RPC_Tests.cs +++ b/Tests/MQTTnet.Core.Tests/RPC_Tests.cs @@ -8,6 +8,8 @@ using MQTTnet.Client.Receiving; using MQTTnet.Exceptions; using MQTTnet.Extensions.Rpc; using MQTTnet.Protocol; +using MQTTnet.Client.Options; +using MQTTnet.Formatter; namespace MQTTnet.Tests { @@ -15,26 +17,39 @@ namespace MQTTnet.Tests public class RPC_Tests { [TestMethod] - public async Task Execute_Success() + public Task Execute_Success_With_QoS_0() { - using (var testEnvironment = new TestEnvironment()) - { - await testEnvironment.StartServerAsync(); - var responseSender = await testEnvironment.ConnectClientAsync(); - await responseSender.SubscribeAsync("MQTTnet.RPC/+/ping"); + return Execute_Success(MqttQualityOfServiceLevel.AtMostOnce, MqttProtocolVersion.V311); + } - responseSender.ApplicationMessageReceivedHandler = new MqttApplicationMessageReceivedHandlerDelegate(async e => - { - await responseSender.PublishAsync(e.ApplicationMessage.Topic + "/response", "pong"); - }); + [TestMethod] + public Task Execute_Success_With_QoS_1() + { + return Execute_Success(MqttQualityOfServiceLevel.AtLeastOnce, MqttProtocolVersion.V311); + } - var requestSender = await testEnvironment.ConnectClientAsync(); + [TestMethod] + public Task Execute_Success_With_QoS_2() + { + return Execute_Success(MqttQualityOfServiceLevel.ExactlyOnce, MqttProtocolVersion.V311); + } - var rpcClient = new MqttRpcClient(requestSender); - var response = await rpcClient.ExecuteAsync(TimeSpan.FromSeconds(5), "ping", "", MqttQualityOfServiceLevel.AtMostOnce); + [TestMethod] + public Task Execute_Success_With_QoS_0_MQTT_V5() + { + return Execute_Success(MqttQualityOfServiceLevel.AtMostOnce, MqttProtocolVersion.V500); + } - Assert.AreEqual("pong", Encoding.UTF8.GetString(response)); - } + [TestMethod] + public Task Execute_Success_With_QoS_1_MQTT_V5() + { + return Execute_Success(MqttQualityOfServiceLevel.AtLeastOnce, MqttProtocolVersion.V500); + } + + [TestMethod] + public Task Execute_Success_With_QoS_2_MQTT_V5() + { + return Execute_Success(MqttQualityOfServiceLevel.ExactlyOnce, MqttProtocolVersion.V500); } [TestMethod] @@ -51,5 +66,27 @@ namespace MQTTnet.Tests await rpcClient.ExecuteAsync(TimeSpan.FromSeconds(2), "ping", "", MqttQualityOfServiceLevel.AtMostOnce); } } + + private async Task Execute_Success(MqttQualityOfServiceLevel qosLevel, MqttProtocolVersion protocolVersion) + { + using (var testEnvironment = new TestEnvironment()) + { + await testEnvironment.StartServerAsync(); + var responseSender = await testEnvironment.ConnectClientAsync(new MqttClientOptionsBuilder().WithProtocolVersion(protocolVersion)); + await responseSender.SubscribeAsync("MQTTnet.RPC/+/ping"); + + responseSender.ApplicationMessageReceivedHandler = new MqttApplicationMessageReceivedHandlerDelegate(async e => + { + await responseSender.PublishAsync(e.ApplicationMessage.Topic + "/response", "pong"); + }); + + var requestSender = await testEnvironment.ConnectClientAsync(); + + var rpcClient = new MqttRpcClient(requestSender); + var response = await rpcClient.ExecuteAsync(TimeSpan.FromSeconds(5), "ping", "", qosLevel); + + Assert.AreEqual("pong", Encoding.UTF8.GetString(response)); + } + } } } diff --git a/Tests/MQTTnet.Core.Tests/Session_Tests.cs b/Tests/MQTTnet.Core.Tests/Session_Tests.cs new file mode 100644 index 0000000..d06bd4e --- /dev/null +++ b/Tests/MQTTnet.Core.Tests/Session_Tests.cs @@ -0,0 +1,61 @@ +using System.Text; +using System.Threading.Tasks; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using MQTTnet.Client; +using MQTTnet.Client.Subscribing; +using MQTTnet.Server; +using MQTTnet.Tests.Mockups; + +namespace MQTTnet.Tests +{ + [TestClass] + public class Session_Tests + { + [TestMethod] + public async Task Set_Session_Item() + { + using (var testEnvironment = new TestEnvironment()) + { + var serverOptions = new MqttServerOptionsBuilder() + .WithConnectionValidator(delegate (MqttConnectionValidatorContext context) + { + // Don't validate anything. Just set some session items. + context.SessionItems["can_subscribe_x"] = true; + context.SessionItems["default_payload"] = "Hello World"; + }) + .WithSubscriptionInterceptor(delegate (MqttSubscriptionInterceptorContext context) + { + if (context.TopicFilter.Topic == "x") + { + context.AcceptSubscription = context.SessionItems["can_subscribe_x"] as bool? == true; + } + }) + .WithApplicationMessageInterceptor(delegate (MqttApplicationMessageInterceptorContext context) + { + context.ApplicationMessage.Payload = Encoding.UTF8.GetBytes(context.SessionItems["default_payload"] as string); + }); + + await testEnvironment.StartServerAsync(serverOptions); + + string receivedPayload = null; + + var client = await testEnvironment.ConnectClientAsync(); + client.UseApplicationMessageReceivedHandler(delegate(MqttApplicationMessageReceivedEventArgs args) + { + receivedPayload = args.ApplicationMessage.ConvertPayloadToString(); + }); + + var subscribeResult = await client.SubscribeAsync("x"); + + Assert.AreEqual(MqttClientSubscribeResultCode.GrantedQoS0, subscribeResult.Items[0].ResultCode); + + var client2 = await testEnvironment.ConnectClientAsync(); + await client2.PublishAsync("x"); + + await Task.Delay(1000); + + Assert.AreEqual("Hello World", receivedPayload); + } + } + } +}