瀏覽代碼

Remove dedicated values from contexts and only provide session items.

release/3.x.x
Christian Kratky 5 年之前
父節點
當前提交
6c8db47e25
共有 11 個檔案被更改,包括 145 行新增88 行删除
  1. +3
    -6
      Source/MQTTnet.Server/Mqtt/MqttApplicationMessageInterceptor.cs
  2. +7
    -0
      Source/MQTTnet.Server/Mqtt/MqttServerConnectionValidator.cs
  3. +6
    -4
      Source/MQTTnet.Server/Mqtt/MqttSubscriptionInterceptor.cs
  4. +8
    -3
      Source/MQTTnet/Server/MqttApplicationMessageInterceptorContext.cs
  5. +0
    -54
      Source/MQTTnet/Server/MqttBaseInterceptorContext.cs
  6. +1
    -1
      Source/MQTTnet/Server/MqttClientSession.cs
  7. +2
    -5
      Source/MQTTnet/Server/MqttClientSessionsManager.cs
  8. +5
    -5
      Source/MQTTnet/Server/MqttClientSubscriptionsManager.cs
  9. +42
    -4
      Source/MQTTnet/Server/MqttConnectionValidatorContext.cs
  10. +10
    -6
      Source/MQTTnet/Server/MqttSubscriptionInterceptorContext.cs
  11. +61
    -0
      Tests/MQTTnet.Core.Tests/Session_Tests.cs

+ 3
- 6
Source/MQTTnet.Server/Mqtt/MqttApplicationMessageInterceptor.cs 查看文件

@@ -22,16 +22,13 @@ namespace MQTTnet.Server.Mqtt
{
try
{
var sessionItems = (PythonDictionary)context.SessionItems[MqttServerConnectionValidator.WrappedSessionItemsKey];

var pythonContext = new PythonDictionary
{
{ "client_id", context.ClientId },
{ "session_items", sessionItems },
{ "retain", context.ApplicationMessage.Retain },
{ "username", context.Username },
{ "password", context.Password },
{ "raw_password", new Bytes(context.RawPassword ?? new byte[0]) },
{ "clean_session", context.CleanSession},
{ "authentication_method", context.AuthenticationMethod},
{ "authentication_data", new Bytes(context.AuthenticationData ?? new byte[0]) },
{ "accept_publish", context.AcceptPublish },
{ "close_connection", context.CloseConnection },
{ "topic", context.ApplicationMessage.Topic },


+ 7
- 0
Source/MQTTnet.Server/Mqtt/MqttServerConnectionValidator.cs 查看文件

@@ -9,6 +9,8 @@ namespace MQTTnet.Server.Mqtt
{
public class MqttServerConnectionValidator : IMqttServerConnectionValidator
{
public const string WrappedSessionItemsKey = "WRAPPED_ITEMS";

private readonly PythonScriptHostService _pythonScriptHostService;
private readonly ILogger _logger;

@@ -22,6 +24,8 @@ namespace MQTTnet.Server.Mqtt
{
try
{
var sessionItems = new PythonDictionary();

var pythonContext = new PythonDictionary
{
{ "endpoint", context.Endpoint },
@@ -33,6 +37,7 @@ namespace MQTTnet.Server.Mqtt
{ "clean_session", context.CleanSession},
{ "authentication_method", context.AuthenticationMethod},
{ "authentication_data", new Bytes(context.AuthenticationData ?? new byte[0]) },
{ "session_items", sessionItems },

{ "result", PythonConvert.Pythonfy(context.ReasonCode) }
};
@@ -40,6 +45,8 @@ namespace MQTTnet.Server.Mqtt
_pythonScriptHostService.InvokeOptionalFunction("on_validate_client_connection", pythonContext);

context.ReasonCode = PythonConvert.ParseEnum<MqttConnectReasonCode>((string)pythonContext["result"]);

context.SessionItems[WrappedSessionItemsKey] = sessionItems;
}
catch (Exception exception)
{


+ 6
- 4
Source/MQTTnet.Server/Mqtt/MqttSubscriptionInterceptor.cs 查看文件

@@ -21,14 +21,16 @@ namespace MQTTnet.Server.Mqtt
{
try
{
var sessionItems = (PythonDictionary)context.SessionItems[MqttServerConnectionValidator.WrappedSessionItemsKey];

var pythonContext = new PythonDictionary
{
{ "accept_subscription", context.AcceptSubscription },
{ "close_connection", context.CloseConnection },

{ "client_id", context.ClientId },
{ "session_items", sessionItems },
{ "topic", context.TopicFilter.Topic },
{ "qos", (int)context.TopicFilter.QualityOfServiceLevel }
{ "qos", (int)context.TopicFilter.QualityOfServiceLevel },
{ "accept_subscription", context.AcceptSubscription },
{ "close_connection", context.CloseConnection }
};

_pythonScriptHostService.InvokeOptionalFunction("on_intercept_subscription", pythonContext);


+ 8
- 3
Source/MQTTnet/Server/MqttApplicationMessageInterceptorContext.cs 查看文件

@@ -1,20 +1,25 @@
using System.Collections.Generic;
using MQTTnet.Packets;

namespace MQTTnet.Server
{
public class MqttApplicationMessageInterceptorContext : MqttBaseInterceptorContext
public class MqttApplicationMessageInterceptorContext
{
public MqttApplicationMessageInterceptorContext(string clientId, IDictionary<object, object> sessionItems, MqttConnectPacket connectPacket, MqttApplicationMessage applicationMessage) : base(connectPacket, sessionItems)
public MqttApplicationMessageInterceptorContext(string clientId, IDictionary<object, object> sessionItems, MqttApplicationMessage applicationMessage)
{
ClientId = clientId;
ApplicationMessage = applicationMessage;
SessionItems = sessionItems;
}

public string ClientId { get; }

public MqttApplicationMessage ApplicationMessage { get; set; }

/// <summary>
/// Gets or sets a key/value collection that can be used to share data within the scope of this session.
/// </summary>
public IDictionary<object, object> SessionItems { get; }

public bool AcceptPublish { get; set; } = true;

public bool CloseConnection { get; set; }


+ 0
- 54
Source/MQTTnet/Server/MqttBaseInterceptorContext.cs 查看文件

@@ -1,54 +0,0 @@
using System.Collections.Generic;
using System.Text;
using MQTTnet.Packets;

namespace MQTTnet.Server
{
public class MqttBaseInterceptorContext
{
private readonly MqttConnectPacket _connectPacket;

protected MqttBaseInterceptorContext(MqttConnectPacket connectPacket, IDictionary<object, object> sessionItems)
{
_connectPacket = connectPacket;
SessionItems = sessionItems;
}

public string Username => _connectPacket?.Username;

public byte[] RawPassword => _connectPacket?.Password;

public string Password => Encoding.UTF8.GetString(RawPassword ?? new byte[0]);

public MqttApplicationMessage WillMessage => _connectPacket?.WillMessage;

public bool? CleanSession => _connectPacket?.CleanSession;

public ushort? KeepAlivePeriod => _connectPacket?.KeepAlivePeriod;

public List<MqttUserProperty> UserProperties => _connectPacket?.Properties?.UserProperties;

public byte[] AuthenticationData => _connectPacket?.Properties?.AuthenticationData;

public string AuthenticationMethod => _connectPacket?.Properties?.AuthenticationMethod;

public uint? MaximumPacketSize => _connectPacket?.Properties?.MaximumPacketSize;

public ushort? ReceiveMaximum => _connectPacket?.Properties?.ReceiveMaximum;

public ushort? TopicAliasMaximum => _connectPacket?.Properties?.TopicAliasMaximum;

public bool? RequestProblemInformation => _connectPacket?.Properties?.RequestProblemInformation;

public bool? RequestResponseInformation => _connectPacket?.Properties?.RequestResponseInformation;

public uint? SessionExpiryInterval => _connectPacket?.Properties?.SessionExpiryInterval;

public uint? WillDelayInterval => _connectPacket?.Properties?.WillDelayInterval;

/// <summary>
/// Gets or sets a key/value collection that can be used to share data within the scope of this session.
/// </summary>
public IDictionary<object, object> SessionItems { get; }
}
}

+ 1
- 1
Source/MQTTnet/Server/MqttClientSession.cs 查看文件

@@ -54,7 +54,7 @@ namespace MQTTnet.Server

public async Task SubscribeAsync(ICollection<TopicFilter> topicFilters, MqttRetainedMessagesManager retainedMessagesManager)
{
await SubscriptionsManager.SubscribeAsync(topicFilters, null).ConfigureAwait(false);
await SubscriptionsManager.SubscribeAsync(topicFilters).ConfigureAwait(false);

var matchingRetainedMessages = await retainedMessagesManager.GetSubscribedMessagesAsync(topicFilters).ConfigureAwait(false);
foreach (var matchingRetainedMessage in matchingRetainedMessages)


+ 2
- 5
Source/MQTTnet/Server/MqttClientSessionsManager.cs 查看文件

@@ -290,7 +290,7 @@ namespace MQTTnet.Server

private async Task<MqttConnectionValidatorContext> ValidateConnectionAsync(MqttConnectPacket connectPacket, IMqttChannelAdapter channelAdapter)
{
var context = new MqttConnectionValidatorContext(connectPacket, channelAdapter);
var context = new MqttConnectionValidatorContext(connectPacket, channelAdapter, new ConcurrentDictionary<object, object>());

var connectionValidator = _options.ConnectionValidator;

@@ -372,23 +372,20 @@ namespace MQTTnet.Server

string senderClientId;
IDictionary<object, object> sessionItems;
MqttConnectPacket connectPacket;

var messageIsFromServer = senderConnection == null;
if (messageIsFromServer)
{
senderClientId = _options.ClientId;
sessionItems = _serverSessionItems;
connectPacket = null;
}
else
{
senderClientId = senderConnection.ClientId;
sessionItems = senderConnection.Session.Items;
connectPacket = senderConnection.ConnectPacket;
}

var interceptorContext = new MqttApplicationMessageInterceptorContext(senderClientId, sessionItems, connectPacket, applicationMessage);
var interceptorContext = new MqttApplicationMessageInterceptorContext(senderClientId, sessionItems, applicationMessage);
await interceptor.InterceptApplicationMessagePublishAsync(interceptorContext).ConfigureAwait(false);
return interceptorContext;
}


+ 5
- 5
Source/MQTTnet/Server/MqttClientSubscriptionsManager.cs 查看文件

@@ -40,7 +40,7 @@ namespace MQTTnet.Server

foreach (var originalTopicFilter in subscribePacket.TopicFilters)
{
var interceptorContext = await InterceptSubscribeAsync(originalTopicFilter, connectPacket).ConfigureAwait(false);
var interceptorContext = await InterceptSubscribeAsync(originalTopicFilter).ConfigureAwait(false);

var finalTopicFilter = interceptorContext.TopicFilter;

@@ -74,13 +74,13 @@ namespace MQTTnet.Server
return result;
}

public async Task SubscribeAsync(IEnumerable<TopicFilter> topicFilters, MqttConnectPacket connectPacket)
public async Task SubscribeAsync(IEnumerable<TopicFilter> topicFilters)
{
if (topicFilters == null) throw new ArgumentNullException(nameof(topicFilters));

foreach (var topicFilter in topicFilters)
{
var interceptorContext = await InterceptSubscribeAsync(topicFilter, connectPacket).ConfigureAwait(false);
var interceptorContext = await InterceptSubscribeAsync(topicFilter).ConfigureAwait(false);
if (!interceptorContext.AcceptSubscription)
{
continue;
@@ -195,9 +195,9 @@ namespace MQTTnet.Server
}
}

private async Task<MqttSubscriptionInterceptorContext> InterceptSubscribeAsync(TopicFilter topicFilter, MqttConnectPacket connectPacket)
private async Task<MqttSubscriptionInterceptorContext> InterceptSubscribeAsync(TopicFilter topicFilter)
{
var context = new MqttSubscriptionInterceptorContext(_clientSession.ClientId, topicFilter, connectPacket, _clientSession.Items);
var context = new MqttSubscriptionInterceptorContext(_clientSession.ClientId, topicFilter, _clientSession.Items);
if (_serverOptions.SubscriptionInterceptor != null)
{
await _serverOptions.SubscriptionInterceptor.InterceptSubscriptionAsync(context).ConfigureAwait(false);


+ 42
- 4
Source/MQTTnet/Server/MqttConnectionValidatorContext.cs 查看文件

@@ -1,7 +1,7 @@
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Security.Cryptography.X509Certificates;
using System.Text;
using MQTTnet.Adapter;
using MQTTnet.Formatter;
using MQTTnet.Packets;
@@ -9,15 +9,16 @@ using MQTTnet.Protocol;

namespace MQTTnet.Server
{
public class MqttConnectionValidatorContext : MqttBaseInterceptorContext
public class MqttConnectionValidatorContext
{
private readonly MqttConnectPacket _connectPacket;
private readonly IMqttChannelAdapter _clientAdapter;

public MqttConnectionValidatorContext(MqttConnectPacket connectPacket, IMqttChannelAdapter clientAdapter) : base(connectPacket, new ConcurrentDictionary<object, object>())
public MqttConnectionValidatorContext(MqttConnectPacket connectPacket, IMqttChannelAdapter clientAdapter, IDictionary<object, object> sessionItems)
{
_connectPacket = connectPacket;
_clientAdapter = clientAdapter ?? throw new ArgumentNullException(nameof(clientAdapter));
SessionItems = sessionItems;
}

public string ClientId => _connectPacket.ClientId;
@@ -29,7 +30,44 @@ namespace MQTTnet.Server
public X509Certificate2 ClientCertificate => _clientAdapter.ClientCertificate;

public MqttProtocolVersion ProtocolVersion => _clientAdapter.PacketFormatterAdapter.ProtocolVersion;

public string Username => _connectPacket?.Username;

public byte[] RawPassword => _connectPacket?.Password;

public string Password => Encoding.UTF8.GetString(RawPassword ?? new byte[0]);

public MqttApplicationMessage WillMessage => _connectPacket?.WillMessage;

public bool? CleanSession => _connectPacket?.CleanSession;

public ushort? KeepAlivePeriod => _connectPacket?.KeepAlivePeriod;

public List<MqttUserProperty> UserProperties => _connectPacket?.Properties?.UserProperties;

public byte[] AuthenticationData => _connectPacket?.Properties?.AuthenticationData;

public string AuthenticationMethod => _connectPacket?.Properties?.AuthenticationMethod;

public uint? MaximumPacketSize => _connectPacket?.Properties?.MaximumPacketSize;

public ushort? ReceiveMaximum => _connectPacket?.Properties?.ReceiveMaximum;

public ushort? TopicAliasMaximum => _connectPacket?.Properties?.TopicAliasMaximum;

public bool? RequestProblemInformation => _connectPacket?.Properties?.RequestProblemInformation;

public bool? RequestResponseInformation => _connectPacket?.Properties?.RequestResponseInformation;

public uint? SessionExpiryInterval => _connectPacket?.Properties?.SessionExpiryInterval;

public uint? WillDelayInterval => _connectPacket?.Properties?.WillDelayInterval;

/// <summary>
/// Gets or sets a key/value collection that can be used to share data within the scope of this session.
/// </summary>
public IDictionary<object, object> SessionItems { get; }

/// <summary>
/// This is used for MQTTv3 only.
/// </summary>


+ 10
- 6
Source/MQTTnet/Server/MqttSubscriptionInterceptorContext.cs 查看文件

@@ -1,21 +1,25 @@
using System;
using System.Collections.Generic;
using MQTTnet.Packets;
using System.Collections.Generic;

namespace MQTTnet.Server
{
public class MqttSubscriptionInterceptorContext : MqttBaseInterceptorContext
public class MqttSubscriptionInterceptorContext
{
public MqttSubscriptionInterceptorContext(string clientId, TopicFilter topicFilter, MqttConnectPacket connectPacket, IDictionary<object, object> sessionItems) : base(connectPacket, sessionItems)
public MqttSubscriptionInterceptorContext(string clientId, TopicFilter topicFilter, IDictionary<object, object> sessionItems)
{
ClientId = clientId;
TopicFilter = topicFilter ?? throw new ArgumentNullException(nameof(topicFilter));
TopicFilter = topicFilter;
SessionItems = sessionItems;
}

public string ClientId { get; }

public TopicFilter TopicFilter { get; set; }

/// <summary>
/// Gets or sets a key/value collection that can be used to share data within the scope of this session.
/// </summary>
public IDictionary<object, object> SessionItems { get; }

public bool AcceptSubscription { get; set; } = true;

public bool CloseConnection { get; set; }


+ 61
- 0
Tests/MQTTnet.Core.Tests/Session_Tests.cs 查看文件

@@ -0,0 +1,61 @@
using System.Text;
using System.Threading.Tasks;
using Microsoft.VisualStudio.TestTools.UnitTesting;
using MQTTnet.Client;
using MQTTnet.Client.Subscribing;
using MQTTnet.Server;
using MQTTnet.Tests.Mockups;

namespace MQTTnet.Tests
{
[TestClass]
public class Session_Tests
{
[TestMethod]
public async Task Set_Session_Item()
{
using (var testEnvironment = new TestEnvironment())
{
var serverOptions = new MqttServerOptionsBuilder()
.WithConnectionValidator(delegate (MqttConnectionValidatorContext context)
{
// Don't validate anything. Just set some session items.
context.SessionItems["can_subscribe_x"] = true;
context.SessionItems["default_payload"] = "Hello World";
})
.WithSubscriptionInterceptor(delegate (MqttSubscriptionInterceptorContext context)
{
if (context.TopicFilter.Topic == "x")
{
context.AcceptSubscription = context.SessionItems["can_subscribe_x"] as bool? == true;
}
})
.WithApplicationMessageInterceptor(delegate (MqttApplicationMessageInterceptorContext context)
{
context.ApplicationMessage.Payload = Encoding.UTF8.GetBytes(context.SessionItems["default_payload"] as string);
});

await testEnvironment.StartServerAsync(serverOptions);

string receivedPayload = null;

var client = await testEnvironment.ConnectClientAsync();
client.UseApplicationMessageReceivedHandler(delegate(MqttApplicationMessageReceivedEventArgs args)
{
receivedPayload = args.ApplicationMessage.ConvertPayloadToString();
});

var subscribeResult = await client.SubscribeAsync("x");

Assert.AreEqual(MqttClientSubscribeResultCode.GrantedQoS0, subscribeResult.Items[0].ResultCode);

var client2 = await testEnvironment.ConnectClientAsync();
await client2.PublishAsync("x");

await Task.Delay(1000);

Assert.AreEqual("Hello World", receivedPayload);
}
}
}
}

Loading…
取消
儲存