diff --git a/Frameworks/MQTTnet.NetStandard/Server/IMqttServer.cs b/Frameworks/MQTTnet.NetStandard/Server/IMqttServer.cs index 926ab9d..99ad0a0 100644 --- a/Frameworks/MQTTnet.NetStandard/Server/IMqttServer.cs +++ b/Frameworks/MQTTnet.NetStandard/Server/IMqttServer.cs @@ -12,6 +12,9 @@ namespace MQTTnet.Server Task> GetConnectedClientsAsync(); + Task SubscribeAsync(string clientId, IList topicFilters); + Task UnsubscribeAsync(string clientId, IList topicFilters); + Task StartAsync(IMqttServerOptions options); Task StopAsync(); } diff --git a/Frameworks/MQTTnet.NetStandard/Server/MqttClientSession.cs b/Frameworks/MQTTnet.NetStandard/Server/MqttClientSession.cs index 9615f86..4098616 100644 --- a/Frameworks/MQTTnet.NetStandard/Server/MqttClientSession.cs +++ b/Frameworks/MQTTnet.NetStandard/Server/MqttClientSession.cs @@ -1,4 +1,5 @@ using System; +using System.Collections.Generic; using System.Diagnostics; using System.Threading; using System.Threading.Tasks; @@ -43,7 +44,7 @@ namespace MQTTnet.Server ClientId = clientId; - _subscriptionsManager = new MqttClientSubscriptionsManager(_options); + _subscriptionsManager = new MqttClientSubscriptionsManager(_options, clientId); _pendingMessagesQueue = new MqttClientPendingMessagesQueue(_options, this, _logger); } @@ -143,6 +144,24 @@ namespace MQTTnet.Server _pendingMessagesQueue.Enqueue(publishPacket); } + public Task SubscribeAsync(IList topicFilters) + { + return _subscriptionsManager.SubscribeAsync(new MqttSubscribePacket + { + TopicFilters = topicFilters + }); + } + + public Task UnsubscribeAsync(IList topicFilters) + { + if (topicFilters == null) throw new ArgumentNullException(nameof(topicFilters)); + + return _subscriptionsManager.UnsubscribeAsync(new MqttUnsubscribePacket + { + TopicFilters = topicFilters + }); + } + public void Dispose() { _pendingMessagesQueue?.Dispose(); @@ -231,7 +250,7 @@ namespace MQTTnet.Server private async Task HandleIncomingSubscribePacketAsync(IMqttChannelAdapter adapter, MqttSubscribePacket subscribePacket, CancellationToken cancellationToken) { - var subscribeResult = await _subscriptionsManager.SubscribeAsync(subscribePacket, ClientId); + var subscribeResult = await _subscriptionsManager.SubscribeAsync(subscribePacket).ConfigureAwait(false); await adapter.SendPacketsAsync(_options.DefaultCommunicationTimeout, cancellationToken, subscribeResult.ResponsePacket).ConfigureAwait(false); if (subscribeResult.CloseConnection) @@ -245,8 +264,7 @@ namespace MQTTnet.Server private async Task HandleIncomingUnsubscribePacketAsync(IMqttChannelAdapter adapter, MqttUnsubscribePacket unsubscribePacket, CancellationToken cancellationToken) { - var unsubscribeResult = await _subscriptionsManager.UnsubscribeAsync(unsubscribePacket); - + var unsubscribeResult = await _subscriptionsManager.UnsubscribeAsync(unsubscribePacket).ConfigureAwait(false); await adapter.SendPacketsAsync(_options.DefaultCommunicationTimeout, cancellationToken, unsubscribeResult); } diff --git a/Frameworks/MQTTnet.NetStandard/Server/MqttClientSessionsManager.cs b/Frameworks/MQTTnet.NetStandard/Server/MqttClientSessionsManager.cs index 28dac0d..6fc8a71 100644 --- a/Frameworks/MQTTnet.NetStandard/Server/MqttClientSessionsManager.cs +++ b/Frameworks/MQTTnet.NetStandard/Server/MqttClientSessionsManager.cs @@ -168,6 +168,48 @@ namespace MQTTnet.Server } } + public async Task SubscribeAsync(string clientId, IList topicFilters) + { + if (clientId == null) throw new ArgumentNullException(nameof(clientId)); + if (topicFilters == null) throw new ArgumentNullException(nameof(topicFilters)); + + await _semaphore.WaitAsync().ConfigureAwait(false); + try + { + if (!_sessions.TryGetValue(clientId, out var session)) + { + throw new InvalidOperationException($"Client session {clientId} is unknown."); + } + + await session.SubscribeAsync(topicFilters); + } + finally + { + _semaphore.Release(); + } + } + + public async Task UnsubscribeAsync(string clientId, IList topicFilters) + { + if (clientId == null) throw new ArgumentNullException(nameof(clientId)); + if (topicFilters == null) throw new ArgumentNullException(nameof(topicFilters)); + + await _semaphore.WaitAsync().ConfigureAwait(false); + try + { + if (!_sessions.TryGetValue(clientId, out var session)) + { + throw new InvalidOperationException($"Client session {clientId} is unknown."); + } + + await session.UnsubscribeAsync(topicFilters); + } + finally + { + _semaphore.Release(); + } + } + private MqttApplicationMessage InterceptApplicationMessage(MqttClientSession senderClientSession, MqttApplicationMessage applicationMessage) { if (_options.ApplicationMessageInterceptor == null) diff --git a/Frameworks/MQTTnet.NetStandard/Server/MqttClientSubscriptionsManager.cs b/Frameworks/MQTTnet.NetStandard/Server/MqttClientSubscriptionsManager.cs index 2151f77..77949d4 100644 --- a/Frameworks/MQTTnet.NetStandard/Server/MqttClientSubscriptionsManager.cs +++ b/Frameworks/MQTTnet.NetStandard/Server/MqttClientSubscriptionsManager.cs @@ -13,13 +13,15 @@ namespace MQTTnet.Server private readonly SemaphoreSlim _semaphore = new SemaphoreSlim(1, 1); private readonly Dictionary _subscriptions = new Dictionary(); private readonly IMqttServerOptions _options; + private readonly string _clientId; - public MqttClientSubscriptionsManager(IMqttServerOptions options) + public MqttClientSubscriptionsManager(IMqttServerOptions options, string clientId) { _options = options ?? throw new ArgumentNullException(nameof(options)); + _clientId = clientId ?? throw new ArgumentNullException(nameof(clientId)); } - public async Task SubscribeAsync(MqttSubscribePacket subscribePacket, string clientId) + public async Task SubscribeAsync(MqttSubscribePacket subscribePacket) { if (subscribePacket == null) throw new ArgumentNullException(nameof(subscribePacket)); @@ -34,7 +36,7 @@ namespace MQTTnet.Server { foreach (var topicFilter in subscribePacket.TopicFilters) { - var interceptorContext = InterceptSubscribe(clientId, topicFilter); + var interceptorContext = InterceptSubscribe(topicFilter); if (!interceptorContext.AcceptSubscription) { result.ResponsePacket.SubscribeReturnCodes.Add(MqttSubscribeReturnCode.Failure); @@ -117,9 +119,9 @@ namespace MQTTnet.Server } } - private MqttSubscriptionInterceptorContext InterceptSubscribe(string clientId, TopicFilter topicFilter) + private MqttSubscriptionInterceptorContext InterceptSubscribe(TopicFilter topicFilter) { - var interceptorContext = new MqttSubscriptionInterceptorContext(clientId, topicFilter); + var interceptorContext = new MqttSubscriptionInterceptorContext(_clientId, topicFilter); _options.SubscriptionInterceptor?.Invoke(interceptorContext); return interceptorContext; } diff --git a/Frameworks/MQTTnet.NetStandard/Server/MqttServer.cs b/Frameworks/MQTTnet.NetStandard/Server/MqttServer.cs index 461842f..ac3b6d1 100644 --- a/Frameworks/MQTTnet.NetStandard/Server/MqttServer.cs +++ b/Frameworks/MQTTnet.NetStandard/Server/MqttServer.cs @@ -20,13 +20,9 @@ namespace MQTTnet.Server public MqttServer(IEnumerable adapters, IMqttNetLogger logger) { + if (adapters == null) throw new ArgumentNullException(nameof(adapters)); _logger = logger ?? throw new ArgumentNullException(nameof(logger)); - if (adapters == null) - { - throw new ArgumentNullException(nameof(adapters)); - } - _adapters = adapters.ToList(); } @@ -40,6 +36,22 @@ namespace MQTTnet.Server return _clientSessionsManager.GetConnectedClientsAsync(); } + public Task SubscribeAsync(string clientId, IList topicFilters) + { + if (clientId == null) throw new ArgumentNullException(nameof(clientId)); + if (topicFilters == null) throw new ArgumentNullException(nameof(topicFilters)); + + return _clientSessionsManager.SubscribeAsync(clientId, topicFilters); + } + + public Task UnsubscribeAsync(string clientId, IList topicFilters) + { + if (clientId == null) throw new ArgumentNullException(nameof(clientId)); + if (topicFilters == null) throw new ArgumentNullException(nameof(topicFilters)); + + return _clientSessionsManager.UnsubscribeAsync(clientId, topicFilters); + } + public async Task PublishAsync(IEnumerable applicationMessages) { if (applicationMessages == null) throw new ArgumentNullException(nameof(applicationMessages)); diff --git a/Tests/MQTTnet.Core.Tests/MqttSubscriptionsManagerTests.cs b/Tests/MQTTnet.Core.Tests/MqttSubscriptionsManagerTests.cs index 50acd61..8b2863a 100644 --- a/Tests/MQTTnet.Core.Tests/MqttSubscriptionsManagerTests.cs +++ b/Tests/MQTTnet.Core.Tests/MqttSubscriptionsManagerTests.cs @@ -11,12 +11,12 @@ namespace MQTTnet.Core.Tests [TestMethod] public void MqttSubscriptionsManager_SubscribeSingleSuccess() { - var sm = new MqttClientSubscriptionsManager(new MqttServerOptions()); + var sm = new MqttClientSubscriptionsManager(new MqttServerOptions(), ""); var sp = new MqttSubscribePacket(); sp.TopicFilters.Add(new TopicFilterBuilder().WithTopic("A/B/C").Build()); - sm.SubscribeAsync(sp, "").Wait(); + sm.SubscribeAsync(sp).Wait(); var pp = new MqttApplicationMessage { @@ -32,12 +32,12 @@ namespace MQTTnet.Core.Tests [TestMethod] public void MqttSubscriptionsManager_SubscribeDifferentQoSSuccess() { - var sm = new MqttClientSubscriptionsManager(new MqttServerOptions()); + var sm = new MqttClientSubscriptionsManager(new MqttServerOptions(), ""); var sp = new MqttSubscribePacket(); sp.TopicFilters.Add(new TopicFilter("A/B/C", MqttQualityOfServiceLevel.AtMostOnce)); - sm.SubscribeAsync(sp, "").Wait(); + sm.SubscribeAsync(sp).Wait(); var pp = new MqttApplicationMessage { @@ -53,13 +53,13 @@ namespace MQTTnet.Core.Tests [TestMethod] public void MqttSubscriptionsManager_SubscribeTwoTimesSuccess() { - var sm = new MqttClientSubscriptionsManager(new MqttServerOptions()); + var sm = new MqttClientSubscriptionsManager(new MqttServerOptions(), ""); var sp = new MqttSubscribePacket(); sp.TopicFilters.Add(new TopicFilter("#", MqttQualityOfServiceLevel.AtMostOnce)); sp.TopicFilters.Add(new TopicFilter("A/B/C", MqttQualityOfServiceLevel.AtLeastOnce)); - sm.SubscribeAsync(sp, "").Wait(); + sm.SubscribeAsync(sp).Wait(); var pp = new MqttApplicationMessage { @@ -75,12 +75,12 @@ namespace MQTTnet.Core.Tests [TestMethod] public void MqttSubscriptionsManager_SubscribeSingleNoSuccess() { - var sm = new MqttClientSubscriptionsManager(new MqttServerOptions()); + var sm = new MqttClientSubscriptionsManager(new MqttServerOptions(), ""); var sp = new MqttSubscribePacket(); sp.TopicFilters.Add(new TopicFilterBuilder().WithTopic("A/B/C").Build()); - sm.SubscribeAsync(sp, "").Wait(); + sm.SubscribeAsync(sp).Wait(); var pp = new MqttApplicationMessage { @@ -94,12 +94,12 @@ namespace MQTTnet.Core.Tests [TestMethod] public void MqttSubscriptionsManager_SubscribeAndUnsubscribeSingle() { - var sm = new MqttClientSubscriptionsManager(new MqttServerOptions()); + var sm = new MqttClientSubscriptionsManager(new MqttServerOptions(), ""); var sp = new MqttSubscribePacket(); sp.TopicFilters.Add(new TopicFilterBuilder().WithTopic("A/B/C").Build()); - sm.SubscribeAsync(sp, "").Wait(); + sm.SubscribeAsync(sp).Wait(); var pp = new MqttApplicationMessage {