Browse Source

Add server methods for controlling the client subscriptions.

release/3.x.x
Christian 6 years ago
parent
commit
30754433eb
6 changed files with 101 additions and 24 deletions
  1. +3
    -0
      Frameworks/MQTTnet.NetStandard/Server/IMqttServer.cs
  2. +22
    -4
      Frameworks/MQTTnet.NetStandard/Server/MqttClientSession.cs
  3. +42
    -0
      Frameworks/MQTTnet.NetStandard/Server/MqttClientSessionsManager.cs
  4. +7
    -5
      Frameworks/MQTTnet.NetStandard/Server/MqttClientSubscriptionsManager.cs
  5. +17
    -5
      Frameworks/MQTTnet.NetStandard/Server/MqttServer.cs
  6. +10
    -10
      Tests/MQTTnet.Core.Tests/MqttSubscriptionsManagerTests.cs

+ 3
- 0
Frameworks/MQTTnet.NetStandard/Server/IMqttServer.cs View File

@@ -12,6 +12,9 @@ namespace MQTTnet.Server


Task<IList<ConnectedMqttClient>> GetConnectedClientsAsync(); Task<IList<ConnectedMqttClient>> GetConnectedClientsAsync();


Task SubscribeAsync(string clientId, IList<TopicFilter> topicFilters);
Task UnsubscribeAsync(string clientId, IList<string> topicFilters);

Task StartAsync(IMqttServerOptions options); Task StartAsync(IMqttServerOptions options);
Task StopAsync(); Task StopAsync();
} }

+ 22
- 4
Frameworks/MQTTnet.NetStandard/Server/MqttClientSession.cs View File

@@ -1,4 +1,5 @@
using System; using System;
using System.Collections.Generic;
using System.Diagnostics; using System.Diagnostics;
using System.Threading; using System.Threading;
using System.Threading.Tasks; using System.Threading.Tasks;
@@ -43,7 +44,7 @@ namespace MQTTnet.Server


ClientId = clientId; ClientId = clientId;
_subscriptionsManager = new MqttClientSubscriptionsManager(_options);
_subscriptionsManager = new MqttClientSubscriptionsManager(_options, clientId);
_pendingMessagesQueue = new MqttClientPendingMessagesQueue(_options, this, _logger); _pendingMessagesQueue = new MqttClientPendingMessagesQueue(_options, this, _logger);
} }


@@ -143,6 +144,24 @@ namespace MQTTnet.Server
_pendingMessagesQueue.Enqueue(publishPacket); _pendingMessagesQueue.Enqueue(publishPacket);
} }


public Task SubscribeAsync(IList<TopicFilter> topicFilters)
{
return _subscriptionsManager.SubscribeAsync(new MqttSubscribePacket
{
TopicFilters = topicFilters
});
}

public Task UnsubscribeAsync(IList<string> topicFilters)
{
if (topicFilters == null) throw new ArgumentNullException(nameof(topicFilters));

return _subscriptionsManager.UnsubscribeAsync(new MqttUnsubscribePacket
{
TopicFilters = topicFilters
});
}

public void Dispose() public void Dispose()
{ {
_pendingMessagesQueue?.Dispose(); _pendingMessagesQueue?.Dispose();
@@ -231,7 +250,7 @@ namespace MQTTnet.Server


private async Task HandleIncomingSubscribePacketAsync(IMqttChannelAdapter adapter, MqttSubscribePacket subscribePacket, CancellationToken cancellationToken) private async Task HandleIncomingSubscribePacketAsync(IMqttChannelAdapter adapter, MqttSubscribePacket subscribePacket, CancellationToken cancellationToken)
{ {
var subscribeResult = await _subscriptionsManager.SubscribeAsync(subscribePacket, ClientId);
var subscribeResult = await _subscriptionsManager.SubscribeAsync(subscribePacket).ConfigureAwait(false);
await adapter.SendPacketsAsync(_options.DefaultCommunicationTimeout, cancellationToken, subscribeResult.ResponsePacket).ConfigureAwait(false); await adapter.SendPacketsAsync(_options.DefaultCommunicationTimeout, cancellationToken, subscribeResult.ResponsePacket).ConfigureAwait(false);


if (subscribeResult.CloseConnection) if (subscribeResult.CloseConnection)
@@ -245,8 +264,7 @@ namespace MQTTnet.Server


private async Task HandleIncomingUnsubscribePacketAsync(IMqttChannelAdapter adapter, MqttUnsubscribePacket unsubscribePacket, CancellationToken cancellationToken) private async Task HandleIncomingUnsubscribePacketAsync(IMqttChannelAdapter adapter, MqttUnsubscribePacket unsubscribePacket, CancellationToken cancellationToken)
{ {
var unsubscribeResult = await _subscriptionsManager.UnsubscribeAsync(unsubscribePacket);

var unsubscribeResult = await _subscriptionsManager.UnsubscribeAsync(unsubscribePacket).ConfigureAwait(false);
await adapter.SendPacketsAsync(_options.DefaultCommunicationTimeout, cancellationToken, unsubscribeResult); await adapter.SendPacketsAsync(_options.DefaultCommunicationTimeout, cancellationToken, unsubscribeResult);
} }




+ 42
- 0
Frameworks/MQTTnet.NetStandard/Server/MqttClientSessionsManager.cs View File

@@ -168,6 +168,48 @@ namespace MQTTnet.Server
} }
} }


public async Task SubscribeAsync(string clientId, IList<TopicFilter> topicFilters)
{
if (clientId == null) throw new ArgumentNullException(nameof(clientId));
if (topicFilters == null) throw new ArgumentNullException(nameof(topicFilters));

await _semaphore.WaitAsync().ConfigureAwait(false);
try
{
if (!_sessions.TryGetValue(clientId, out var session))
{
throw new InvalidOperationException($"Client session {clientId} is unknown.");
}

await session.SubscribeAsync(topicFilters);
}
finally
{
_semaphore.Release();
}
}

public async Task UnsubscribeAsync(string clientId, IList<string> topicFilters)
{
if (clientId == null) throw new ArgumentNullException(nameof(clientId));
if (topicFilters == null) throw new ArgumentNullException(nameof(topicFilters));

await _semaphore.WaitAsync().ConfigureAwait(false);
try
{
if (!_sessions.TryGetValue(clientId, out var session))
{
throw new InvalidOperationException($"Client session {clientId} is unknown.");
}

await session.UnsubscribeAsync(topicFilters);
}
finally
{
_semaphore.Release();
}
}

private MqttApplicationMessage InterceptApplicationMessage(MqttClientSession senderClientSession, MqttApplicationMessage applicationMessage) private MqttApplicationMessage InterceptApplicationMessage(MqttClientSession senderClientSession, MqttApplicationMessage applicationMessage)
{ {
if (_options.ApplicationMessageInterceptor == null) if (_options.ApplicationMessageInterceptor == null)


+ 7
- 5
Frameworks/MQTTnet.NetStandard/Server/MqttClientSubscriptionsManager.cs View File

@@ -13,13 +13,15 @@ namespace MQTTnet.Server
private readonly SemaphoreSlim _semaphore = new SemaphoreSlim(1, 1); private readonly SemaphoreSlim _semaphore = new SemaphoreSlim(1, 1);
private readonly Dictionary<string, MqttQualityOfServiceLevel> _subscriptions = new Dictionary<string, MqttQualityOfServiceLevel>(); private readonly Dictionary<string, MqttQualityOfServiceLevel> _subscriptions = new Dictionary<string, MqttQualityOfServiceLevel>();
private readonly IMqttServerOptions _options; private readonly IMqttServerOptions _options;
private readonly string _clientId;


public MqttClientSubscriptionsManager(IMqttServerOptions options)
public MqttClientSubscriptionsManager(IMqttServerOptions options, string clientId)
{ {
_options = options ?? throw new ArgumentNullException(nameof(options)); _options = options ?? throw new ArgumentNullException(nameof(options));
_clientId = clientId ?? throw new ArgumentNullException(nameof(clientId));
} }


public async Task<MqttClientSubscribeResult> SubscribeAsync(MqttSubscribePacket subscribePacket, string clientId)
public async Task<MqttClientSubscribeResult> SubscribeAsync(MqttSubscribePacket subscribePacket)
{ {
if (subscribePacket == null) throw new ArgumentNullException(nameof(subscribePacket)); if (subscribePacket == null) throw new ArgumentNullException(nameof(subscribePacket));


@@ -34,7 +36,7 @@ namespace MQTTnet.Server
{ {
foreach (var topicFilter in subscribePacket.TopicFilters) foreach (var topicFilter in subscribePacket.TopicFilters)
{ {
var interceptorContext = InterceptSubscribe(clientId, topicFilter);
var interceptorContext = InterceptSubscribe(topicFilter);
if (!interceptorContext.AcceptSubscription) if (!interceptorContext.AcceptSubscription)
{ {
result.ResponsePacket.SubscribeReturnCodes.Add(MqttSubscribeReturnCode.Failure); result.ResponsePacket.SubscribeReturnCodes.Add(MqttSubscribeReturnCode.Failure);
@@ -117,9 +119,9 @@ namespace MQTTnet.Server
} }
} }


private MqttSubscriptionInterceptorContext InterceptSubscribe(string clientId, TopicFilter topicFilter)
private MqttSubscriptionInterceptorContext InterceptSubscribe(TopicFilter topicFilter)
{ {
var interceptorContext = new MqttSubscriptionInterceptorContext(clientId, topicFilter);
var interceptorContext = new MqttSubscriptionInterceptorContext(_clientId, topicFilter);
_options.SubscriptionInterceptor?.Invoke(interceptorContext); _options.SubscriptionInterceptor?.Invoke(interceptorContext);
return interceptorContext; return interceptorContext;
} }


+ 17
- 5
Frameworks/MQTTnet.NetStandard/Server/MqttServer.cs View File

@@ -20,13 +20,9 @@ namespace MQTTnet.Server


public MqttServer(IEnumerable<IMqttServerAdapter> adapters, IMqttNetLogger logger) public MqttServer(IEnumerable<IMqttServerAdapter> adapters, IMqttNetLogger logger)
{ {
if (adapters == null) throw new ArgumentNullException(nameof(adapters));
_logger = logger ?? throw new ArgumentNullException(nameof(logger)); _logger = logger ?? throw new ArgumentNullException(nameof(logger));


if (adapters == null)
{
throw new ArgumentNullException(nameof(adapters));
}

_adapters = adapters.ToList(); _adapters = adapters.ToList();
} }


@@ -40,6 +36,22 @@ namespace MQTTnet.Server
return _clientSessionsManager.GetConnectedClientsAsync(); return _clientSessionsManager.GetConnectedClientsAsync();
} }


public Task SubscribeAsync(string clientId, IList<TopicFilter> topicFilters)
{
if (clientId == null) throw new ArgumentNullException(nameof(clientId));
if (topicFilters == null) throw new ArgumentNullException(nameof(topicFilters));

return _clientSessionsManager.SubscribeAsync(clientId, topicFilters);
}

public Task UnsubscribeAsync(string clientId, IList<string> topicFilters)
{
if (clientId == null) throw new ArgumentNullException(nameof(clientId));
if (topicFilters == null) throw new ArgumentNullException(nameof(topicFilters));

return _clientSessionsManager.UnsubscribeAsync(clientId, topicFilters);
}

public async Task PublishAsync(IEnumerable<MqttApplicationMessage> applicationMessages) public async Task PublishAsync(IEnumerable<MqttApplicationMessage> applicationMessages)
{ {
if (applicationMessages == null) throw new ArgumentNullException(nameof(applicationMessages)); if (applicationMessages == null) throw new ArgumentNullException(nameof(applicationMessages));


+ 10
- 10
Tests/MQTTnet.Core.Tests/MqttSubscriptionsManagerTests.cs View File

@@ -11,12 +11,12 @@ namespace MQTTnet.Core.Tests
[TestMethod] [TestMethod]
public void MqttSubscriptionsManager_SubscribeSingleSuccess() public void MqttSubscriptionsManager_SubscribeSingleSuccess()
{ {
var sm = new MqttClientSubscriptionsManager(new MqttServerOptions());
var sm = new MqttClientSubscriptionsManager(new MqttServerOptions(), "");


var sp = new MqttSubscribePacket(); var sp = new MqttSubscribePacket();
sp.TopicFilters.Add(new TopicFilterBuilder().WithTopic("A/B/C").Build()); sp.TopicFilters.Add(new TopicFilterBuilder().WithTopic("A/B/C").Build());


sm.SubscribeAsync(sp, "").Wait();
sm.SubscribeAsync(sp).Wait();


var pp = new MqttApplicationMessage var pp = new MqttApplicationMessage
{ {
@@ -32,12 +32,12 @@ namespace MQTTnet.Core.Tests
[TestMethod] [TestMethod]
public void MqttSubscriptionsManager_SubscribeDifferentQoSSuccess() public void MqttSubscriptionsManager_SubscribeDifferentQoSSuccess()
{ {
var sm = new MqttClientSubscriptionsManager(new MqttServerOptions());
var sm = new MqttClientSubscriptionsManager(new MqttServerOptions(), "");


var sp = new MqttSubscribePacket(); var sp = new MqttSubscribePacket();
sp.TopicFilters.Add(new TopicFilter("A/B/C", MqttQualityOfServiceLevel.AtMostOnce)); sp.TopicFilters.Add(new TopicFilter("A/B/C", MqttQualityOfServiceLevel.AtMostOnce));


sm.SubscribeAsync(sp, "").Wait();
sm.SubscribeAsync(sp).Wait();


var pp = new MqttApplicationMessage var pp = new MqttApplicationMessage
{ {
@@ -53,13 +53,13 @@ namespace MQTTnet.Core.Tests
[TestMethod] [TestMethod]
public void MqttSubscriptionsManager_SubscribeTwoTimesSuccess() public void MqttSubscriptionsManager_SubscribeTwoTimesSuccess()
{ {
var sm = new MqttClientSubscriptionsManager(new MqttServerOptions());
var sm = new MqttClientSubscriptionsManager(new MqttServerOptions(), "");


var sp = new MqttSubscribePacket(); var sp = new MqttSubscribePacket();
sp.TopicFilters.Add(new TopicFilter("#", MqttQualityOfServiceLevel.AtMostOnce)); sp.TopicFilters.Add(new TopicFilter("#", MqttQualityOfServiceLevel.AtMostOnce));
sp.TopicFilters.Add(new TopicFilter("A/B/C", MqttQualityOfServiceLevel.AtLeastOnce)); sp.TopicFilters.Add(new TopicFilter("A/B/C", MqttQualityOfServiceLevel.AtLeastOnce));


sm.SubscribeAsync(sp, "").Wait();
sm.SubscribeAsync(sp).Wait();


var pp = new MqttApplicationMessage var pp = new MqttApplicationMessage
{ {
@@ -75,12 +75,12 @@ namespace MQTTnet.Core.Tests
[TestMethod] [TestMethod]
public void MqttSubscriptionsManager_SubscribeSingleNoSuccess() public void MqttSubscriptionsManager_SubscribeSingleNoSuccess()
{ {
var sm = new MqttClientSubscriptionsManager(new MqttServerOptions());
var sm = new MqttClientSubscriptionsManager(new MqttServerOptions(), "");


var sp = new MqttSubscribePacket(); var sp = new MqttSubscribePacket();
sp.TopicFilters.Add(new TopicFilterBuilder().WithTopic("A/B/C").Build()); sp.TopicFilters.Add(new TopicFilterBuilder().WithTopic("A/B/C").Build());


sm.SubscribeAsync(sp, "").Wait();
sm.SubscribeAsync(sp).Wait();


var pp = new MqttApplicationMessage var pp = new MqttApplicationMessage
{ {
@@ -94,12 +94,12 @@ namespace MQTTnet.Core.Tests
[TestMethod] [TestMethod]
public void MqttSubscriptionsManager_SubscribeAndUnsubscribeSingle() public void MqttSubscriptionsManager_SubscribeAndUnsubscribeSingle()
{ {
var sm = new MqttClientSubscriptionsManager(new MqttServerOptions());
var sm = new MqttClientSubscriptionsManager(new MqttServerOptions(), "");


var sp = new MqttSubscribePacket(); var sp = new MqttSubscribePacket();
sp.TopicFilters.Add(new TopicFilterBuilder().WithTopic("A/B/C").Build()); sp.TopicFilters.Add(new TopicFilterBuilder().WithTopic("A/B/C").Build());


sm.SubscribeAsync(sp, "").Wait();
sm.SubscribeAsync(sp).Wait();


var pp = new MqttApplicationMessage var pp = new MqttApplicationMessage
{ {


Loading…
Cancel
Save