ソースを参照

Add interceptor for client subscriptions

release/3.x.x
Christian Kratky 7年前
コミット
ba16ae6568
11個のファイルの変更140行の追加66行の削除
  1. +30
    -29
      MQTTnet.Core/Client/MqttClient.cs
  2. +7
    -0
      MQTTnet.Core/Server/MqttApplicationMessageInterceptorContext.cs
  3. +28
    -14
      MQTTnet.Core/Server/MqttClientSession.cs
  4. +11
    -0
      MQTTnet.Core/Server/MqttClientSubscribeResult.cs
  5. +23
    -7
      MQTTnet.Core/Server/MqttClientSubscriptionsManager.cs
  6. +8
    -2
      MQTTnet.Core/Server/MqttServer.cs
  7. +3
    -1
      MQTTnet.Core/Server/MqttServerOptions.cs
  8. +19
    -0
      MQTTnet.Core/Server/MqttSubscriptionInterceptorContext.cs
  9. +2
    -3
      Tests/MQTTnet.Core.Tests/MqttServerTests.cs
  10. +5
    -4
      Tests/MQTTnet.Core.Tests/MqttSubscriptionsManagerTests.cs
  11. +4
    -6
      Tests/MQTTnet.TestApp.NetCore/ServerTest.cs

+ 30
- 29
MQTTnet.Core/Client/MqttClient.cs ファイルの表示

@@ -149,36 +149,36 @@ namespace MQTTnet.Core.Client
switch (qosGroup.Key)
{
case MqttQualityOfServiceLevel.AtMostOnce:
{
// No packet identifier is used for QoS 0 [3.3.2.2 Packet Identifier]
await _adapter.SendPacketsAsync(_options.CommunicationTimeout, _cancellationTokenSource.Token, qosPackets);
break;
}
{
// No packet identifier is used for QoS 0 [3.3.2.2 Packet Identifier]
await _adapter.SendPacketsAsync(_options.CommunicationTimeout, _cancellationTokenSource.Token, qosPackets);
break;
}
case MqttQualityOfServiceLevel.AtLeastOnce:
{
foreach (var publishPacket in qosPackets)
{
foreach (var publishPacket in qosPackets)
{
publishPacket.PacketIdentifier = GetNewPacketIdentifier();
await SendAndReceiveAsync<MqttPubAckPacket>(publishPacket);
}

break;
publishPacket.PacketIdentifier = GetNewPacketIdentifier();
await SendAndReceiveAsync<MqttPubAckPacket>(publishPacket);
}

break;
}
case MqttQualityOfServiceLevel.ExactlyOnce:
{
foreach (var publishPacket in qosPackets)
{
foreach (var publishPacket in qosPackets)
{
publishPacket.PacketIdentifier = GetNewPacketIdentifier();
var pubRecPacket = await SendAndReceiveAsync<MqttPubRecPacket>(publishPacket).ConfigureAwait(false);
await SendAndReceiveAsync<MqttPubCompPacket>(pubRecPacket.CreateResponse<MqttPubRelPacket>()).ConfigureAwait(false);
}

break;
publishPacket.PacketIdentifier = GetNewPacketIdentifier();
var pubRecPacket = await SendAndReceiveAsync<MqttPubRecPacket>(publishPacket).ConfigureAwait(false);
await SendAndReceiveAsync<MqttPubCompPacket>(pubRecPacket.CreateResponse<MqttPubRelPacket>()).ConfigureAwait(false);
}

break;
}
default:
{
throw new InvalidOperationException();
}
{
throw new InvalidOperationException();
}
}
}
}
@@ -191,7 +191,7 @@ namespace MQTTnet.Core.Client
Username = _options.Credentials?.Username,
Password = _options.Credentials?.Password,
CleanSession = _options.CleanSession,
KeepAlivePeriod = (ushort)_options.KeepAlivePeriod.TotalSeconds,
KeepAlivePeriod = (ushort) _options.KeepAlivePeriod.TotalSeconds,
WillMessage = willApplicationMessage
};

@@ -324,7 +324,7 @@ namespace MQTTnet.Core.Client
if (publishPacket.QualityOfServiceLevel == MqttQualityOfServiceLevel.AtLeastOnce)
{
FireApplicationMessageReceivedEvent(publishPacket);
await SendAsync(new MqttPubAckPacket { PacketIdentifier = publishPacket.PacketIdentifier });
await SendAsync(new MqttPubAckPacket {PacketIdentifier = publishPacket.PacketIdentifier});
return;
}

@@ -337,7 +337,7 @@ namespace MQTTnet.Core.Client
}

FireApplicationMessageReceivedEvent(publishPacket);
await SendAsync(new MqttPubRecPacket { PacketIdentifier = publishPacket.PacketIdentifier });
await SendAsync(new MqttPubRecPacket {PacketIdentifier = publishPacket.PacketIdentifier});
return;
}

@@ -363,12 +363,12 @@ namespace MQTTnet.Core.Client
{
var packetAwaiter = _packetDispatcher.WaitForPacketAsync(requestPacket, typeof(TResponsePacket), _options.CommunicationTimeout);
await _adapter.SendPacketsAsync(_options.CommunicationTimeout, _cancellationTokenSource.Token, requestPacket).ConfigureAwait(false);
return (TResponsePacket)await packetAwaiter.ConfigureAwait(false);
return (TResponsePacket) await packetAwaiter.ConfigureAwait(false);
}

private ushort GetNewPacketIdentifier()
{
return (ushort)Interlocked.Increment(ref _latestPacketIdentifier);
return (ushort) Interlocked.Increment(ref _latestPacketIdentifier);
}

private async Task SendKeepAliveMessagesAsync(CancellationToken cancellationToken)
@@ -465,7 +465,8 @@ namespace MQTTnet.Core.Client
private void StartSendKeepAliveMessages(CancellationToken cancellationToken)
{
#pragma warning disable CS4014 // Because this call is not awaited, execution of the current method continues before the call is completed
Task.Factory.StartNew(() => SendKeepAliveMessagesAsync(cancellationToken), cancellationToken, TaskCreationOptions.LongRunning, TaskScheduler.Default).ConfigureAwait(false);
Task.Factory.StartNew(() => SendKeepAliveMessagesAsync(cancellationToken), cancellationToken, TaskCreationOptions.LongRunning, TaskScheduler.Default)
.ConfigureAwait(false);
#pragma warning restore CS4014 // Because this call is not awaited, execution of the current method continues before the call is completed
}
}

+ 7
- 0
MQTTnet.Core/Server/MqttApplicationMessageInterceptorContext.cs ファイルの表示

@@ -0,0 +1,7 @@
namespace MQTTnet.Core.Server
{
public class MqttApplicationMessageInterceptorContext
{
public MqttApplicationMessage ApplicationMessage { get; set; }
}
}

+ 28
- 14
MQTTnet.Core/Server/MqttClientSession.cs ファイルの表示

@@ -124,7 +124,7 @@ namespace MQTTnet.Core.Server
while (!cancellationToken.IsCancellationRequested)
{
var packet = await adapter.ReceivePacketAsync(TimeSpan.Zero, cancellationToken).ConfigureAwait(false);
await ProcessReceivedPacketAsync(adapter, packet).ConfigureAwait(false);
await ProcessReceivedPacketAsync(adapter, packet, cancellationToken).ConfigureAwait(false);
}
}
catch (OperationCanceledException)
@@ -142,28 +142,35 @@ namespace MQTTnet.Core.Server
}
}

private async Task ProcessReceivedPacketAsync(IMqttCommunicationAdapter adapter, MqttBasePacket packet)
private async Task ProcessReceivedPacketAsync(IMqttCommunicationAdapter adapter, MqttBasePacket packet, CancellationToken cancellationToken)
{
if (packet is MqttSubscribePacket subscribePacket)
{
await adapter.SendPacketsAsync(_options.DefaultCommunicationTimeout, _cancellationTokenSource.Token, _subscriptionsManager.Subscribe(subscribePacket));
var subscribeResult = _subscriptionsManager.Subscribe(subscribePacket);
await adapter.SendPacketsAsync(_options.DefaultCommunicationTimeout, cancellationToken, subscribeResult.ResponsePacket);
EnqueueRetainedMessages(subscribePacket);

if (subscribeResult.CloseConnection)
{
await adapter.SendPacketsAsync(_options.DefaultCommunicationTimeout, cancellationToken, new MqttDisconnectPacket());
Stop();
}
}
else if (packet is MqttUnsubscribePacket unsubscribePacket)
{
await adapter.SendPacketsAsync(_options.DefaultCommunicationTimeout, _cancellationTokenSource.Token, _subscriptionsManager.Unsubscribe(unsubscribePacket));
await adapter.SendPacketsAsync(_options.DefaultCommunicationTimeout, cancellationToken, _subscriptionsManager.Unsubscribe(unsubscribePacket));
}
else if (packet is MqttPublishPacket publishPacket)
{
await HandleIncomingPublishPacketAsync(adapter, publishPacket);
await HandleIncomingPublishPacketAsync(adapter, publishPacket, cancellationToken);
}
else if (packet is MqttPubRelPacket pubRelPacket)
{
await HandleIncomingPubRelPacketAsync(adapter, pubRelPacket);
await HandleIncomingPubRelPacketAsync(adapter, pubRelPacket, cancellationToken);
}
else if (packet is MqttPubRecPacket pubRecPacket)
{
await adapter.SendPacketsAsync(_options.DefaultCommunicationTimeout, _cancellationTokenSource.Token, pubRecPacket.CreateResponse<MqttPubRelPacket>());
await adapter.SendPacketsAsync(_options.DefaultCommunicationTimeout, cancellationToken, pubRecPacket.CreateResponse<MqttPubRelPacket>());
}
else if (packet is MqttPubAckPacket || packet is MqttPubCompPacket)
{
@@ -171,7 +178,7 @@ namespace MQTTnet.Core.Server
}
else if (packet is MqttPingReqPacket)
{
await adapter.SendPacketsAsync(_options.DefaultCommunicationTimeout, _cancellationTokenSource.Token, new MqttPingRespPacket());
await adapter.SendPacketsAsync(_options.DefaultCommunicationTimeout, cancellationToken, new MqttPingRespPacket());
}
else if (packet is MqttDisconnectPacket || packet is MqttConnectPacket)
{
@@ -193,10 +200,17 @@ namespace MQTTnet.Core.Server
}
}

private async Task HandleIncomingPublishPacketAsync(IMqttCommunicationAdapter adapter, MqttPublishPacket publishPacket)
private async Task HandleIncomingPublishPacketAsync(IMqttCommunicationAdapter adapter, MqttPublishPacket publishPacket, CancellationToken cancellationToken)
{
var applicationMessage = publishPacket.ToApplicationMessage();
_options.ApplicationMessageInterceptor?.Invoke(applicationMessage);

var interceptorContext = new MqttApplicationMessageInterceptorContext
{
ApplicationMessage = applicationMessage
};

_options.ApplicationMessageInterceptor?.Invoke(interceptorContext);
applicationMessage = interceptorContext.ApplicationMessage;

if (applicationMessage.Retain)
{
@@ -214,7 +228,7 @@ namespace MQTTnet.Core.Server
{
_sessionsManager.DispatchApplicationMessage(this, applicationMessage);

await adapter.SendPacketsAsync(_options.DefaultCommunicationTimeout, _cancellationTokenSource.Token,
await adapter.SendPacketsAsync(_options.DefaultCommunicationTimeout, cancellationToken,
new MqttPubAckPacket { PacketIdentifier = publishPacket.PacketIdentifier });

return;
@@ -229,7 +243,7 @@ namespace MQTTnet.Core.Server

_sessionsManager.DispatchApplicationMessage(this, applicationMessage);

await adapter.SendPacketsAsync(_options.DefaultCommunicationTimeout, _cancellationTokenSource.Token,
await adapter.SendPacketsAsync(_options.DefaultCommunicationTimeout, cancellationToken,
new MqttPubRecPacket { PacketIdentifier = publishPacket.PacketIdentifier });

return;
@@ -239,14 +253,14 @@ namespace MQTTnet.Core.Server
}
}

private Task HandleIncomingPubRelPacketAsync(IMqttCommunicationAdapter adapter, MqttPubRelPacket pubRelPacket)
private Task HandleIncomingPubRelPacketAsync(IMqttCommunicationAdapter adapter, MqttPubRelPacket pubRelPacket, CancellationToken cancellationToken)
{
lock (_unacknowledgedPublishPackets)
{
_unacknowledgedPublishPackets.Remove(pubRelPacket.PacketIdentifier);
}

return adapter.SendPacketsAsync(_options.DefaultCommunicationTimeout, _cancellationTokenSource.Token, new MqttPubCompPacket { PacketIdentifier = pubRelPacket.PacketIdentifier });
return adapter.SendPacketsAsync(_options.DefaultCommunicationTimeout, cancellationToken, new MqttPubCompPacket { PacketIdentifier = pubRelPacket.PacketIdentifier });
}
}
}

+ 11
- 0
MQTTnet.Core/Server/MqttClientSubscribeResult.cs ファイルの表示

@@ -0,0 +1,11 @@
using MQTTnet.Core.Packets;

namespace MQTTnet.Core.Server
{
public class MqttClientSubscribeResult
{
public MqttSubAckPacket ResponsePacket { get; set; }

public bool CloseConnection { get; set; }
}
}

+ 23
- 7
MQTTnet.Core/Server/MqttClientSubscriptionsManager.cs ファイルの表示

@@ -1,5 +1,6 @@
using System;
using System.Collections.Generic;
using Microsoft.Extensions.Options;
using MQTTnet.Core.Packets;
using MQTTnet.Core.Protocol;

@@ -8,30 +9,45 @@ namespace MQTTnet.Core.Server
public sealed class MqttClientSubscriptionsManager
{
private readonly Dictionary<string, MqttQualityOfServiceLevel> _subscribedTopics = new Dictionary<string, MqttQualityOfServiceLevel>();
private readonly MqttServerOptions _options;

public MqttClientSubscriptionsManager()
public MqttClientSubscriptionsManager(IOptions<MqttServerOptions> options)
{
_options = options?.Value ?? throw new ArgumentNullException(nameof(options));
}

public MqttSubAckPacket Subscribe(MqttSubscribePacket subscribePacket)
public MqttClientSubscribeResult Subscribe(MqttSubscribePacket subscribePacket)
{
if (subscribePacket == null) throw new ArgumentNullException(nameof(subscribePacket));

var responsePacket = subscribePacket.CreateResponse<MqttSubAckPacket>();
var closeConnection = false;

lock (_subscribedTopics)
{
foreach (var topicFilter in subscribePacket.TopicFilters)
{
var interceptorContext = new MqttSubscriptionInterceptorContext("", topicFilter);
_options.SubscriptionsInterceptor?.Invoke(interceptorContext);
responsePacket.SubscribeReturnCodes.Add(interceptorContext.AcceptSubscription ? MqttSubscribeReturnCode.SuccessMaximumQoS1 : MqttSubscribeReturnCode.Failure);
if (interceptorContext.CloseConnection)
{
closeConnection = true;
}


_subscribedTopics[topicFilter.Topic] = topicFilter.QualityOfServiceLevel;
responsePacket.SubscribeReturnCodes.Add(MqttSubscribeReturnCode.SuccessMaximumQoS1); // TODO: Add support for QoS 2.
if (interceptorContext.AcceptSubscription)
{
_subscribedTopics[topicFilter.Topic] = topicFilter.QualityOfServiceLevel;
}
}
}

return responsePacket;
return new MqttClientSubscribeResult
{
ResponsePacket = responsePacket,
CloseConnection = closeConnection
};
}

public MqttUnsubAckPacket Unsubscribe(MqttUnsubscribePacket unsubscribePacket)


+ 8
- 2
MQTTnet.Core/Server/MqttServer.cs ファイルの表示

@@ -56,8 +56,14 @@ namespace MQTTnet.Core.Server

foreach (var applicationMessage in applicationMessages)
{
_options.ApplicationMessageInterceptor?.Invoke(applicationMessage);
_clientSessionsManager.DispatchApplicationMessage(null, applicationMessage);
var interceptorContext = new MqttApplicationMessageInterceptorContext
{
ApplicationMessage = applicationMessage
};

_options.ApplicationMessageInterceptor?.Invoke(interceptorContext);
_clientSessionsManager.DispatchApplicationMessage(null, interceptorContext.ApplicationMessage);
}
}



+ 3
- 1
MQTTnet.Core/Server/MqttServerOptions.cs ファイルの表示

@@ -16,7 +16,9 @@ namespace MQTTnet.Core.Server

public Func<MqttConnectPacket, MqttConnectReturnCode> ConnectionValidator { get; set; }

public Func<MqttApplicationMessage, MqttApplicationMessage> ApplicationMessageInterceptor { get; set; }
public Action<MqttApplicationMessageInterceptorContext> ApplicationMessageInterceptor { get; set; }

public Action<MqttSubscriptionInterceptorContext> SubscriptionsInterceptor { get; set; }

public IMqttServerStorage Storage { get; set; }
}


+ 19
- 0
MQTTnet.Core/Server/MqttSubscriptionInterceptorContext.cs ファイルの表示

@@ -0,0 +1,19 @@
namespace MQTTnet.Core.Server
{
public class MqttSubscriptionInterceptorContext
{
public MqttSubscriptionInterceptorContext(string clientId, TopicFilter topicFilter)
{
ClientId = clientId;
TopicFilter = topicFilter;
}

public string ClientId { get; }

public TopicFilter TopicFilter { get; }

public bool AcceptSubscription { get; set; } = true;

public bool CloseConnection { get; set; }
}
}

+ 2
- 3
Tests/MQTTnet.Core.Tests/MqttServerTests.cs ファイルの表示

@@ -245,10 +245,9 @@ namespace MQTTnet.Core.Tests
[TestMethod]
public async Task MqttServer_InterceptMessage()
{
MqttApplicationMessage Interceptor(MqttApplicationMessage message)
void Interceptor(MqttApplicationMessageInterceptorContext context)
{
message.Payload = Encoding.ASCII.GetBytes("extended");
return message;
context.ApplicationMessage.Payload = Encoding.ASCII.GetBytes("extended");
}

var serverAdapter = new TestMqttServerAdapter();


+ 5
- 4
Tests/MQTTnet.Core.Tests/MqttSubscriptionsManagerTests.cs ファイルの表示

@@ -1,4 +1,5 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using Microsoft.Extensions.Options;
using Microsoft.VisualStudio.TestTools.UnitTesting;
using MQTTnet.Core.Packets;
using MQTTnet.Core.Protocol;
using MQTTnet.Core.Server;
@@ -11,7 +12,7 @@ namespace MQTTnet.Core.Tests
[TestMethod]
public void MqttSubscriptionsManager_SubscribeSingleSuccess()
{
var sm = new MqttClientSubscriptionsManager();
var sm = new MqttClientSubscriptionsManager(new OptionsWrapper<MqttServerOptions>(new MqttServerOptions()));

var sp = new MqttSubscribePacket();
sp.TopicFilters.Add(new TopicFilter("A/B/C", MqttQualityOfServiceLevel.AtMostOnce));
@@ -30,7 +31,7 @@ namespace MQTTnet.Core.Tests
[TestMethod]
public void MqttSubscriptionsManager_SubscribeSingleNoSuccess()
{
var sm = new MqttClientSubscriptionsManager();
var sm = new MqttClientSubscriptionsManager(new OptionsWrapper<MqttServerOptions>(new MqttServerOptions()));

var sp = new MqttSubscribePacket();
sp.TopicFilters.Add(new TopicFilter("A/B/C", MqttQualityOfServiceLevel.AtMostOnce));
@@ -49,7 +50,7 @@ namespace MQTTnet.Core.Tests
[TestMethod]
public void MqttSubscriptionsManager_SubscribeAndUnsubscribeSingle()
{
var sm = new MqttClientSubscriptionsManager();
var sm = new MqttClientSubscriptionsManager(new OptionsWrapper<MqttServerOptions>(new MqttServerOptions()));

var sp = new MqttSubscribePacket();
sp.TopicFilters.Add(new TopicFilter("A/B/C", MqttQualityOfServiceLevel.AtMostOnce));


+ 4
- 6
Tests/MQTTnet.TestApp.NetCore/ServerTest.cs ファイルの表示

@@ -35,16 +35,14 @@ namespace MQTTnet.TestApp.NetCore

options.Storage = new RetainedMessageHandler();

options.ApplicationMessageInterceptor = message =>
options.ApplicationMessageInterceptor = context =>
{
if (MqttTopicFilterComparer.IsMatch(message.Topic, "/myTopic/WithTimestamp/#"))
if (MqttTopicFilterComparer.IsMatch(context.ApplicationMessage.Topic, "/myTopic/WithTimestamp/#"))
{
// Replace the payload with the timestamp. But also extending a JSON
// based payload with the timestamp is a suitable use case.
message.Payload = Encoding.UTF8.GetBytes(DateTime.Now.ToString("O"));
}

return message;
context.ApplicationMessage.Payload = Encoding.UTF8.GetBytes(DateTime.Now.ToString("O"));
}
};
});



読み込み中…
キャンセル
保存