You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 

162 lines
5.8 KiB

  1. using System;
  2. using System.Collections.Generic;
  3. using System.Linq;
  4. using System.Threading;
  5. using System.Threading.Tasks;
  6. using MQTTnet.Packets;
  7. using MQTTnet.Protocol;
  namespace MQTTnet.Server
  {
      /// <summary>
      /// Maintains a table of topic-filter subscriptions and decides whether, and at which
      /// QoS level, an application message matches them. All access to the table is
      /// serialized through a single <see cref="SemaphoreSlim"/>.
      /// </summary>
      public sealed class MqttClientSubscriptionsManager
      {
          // Guards _subscriptions. SemaphoreSlim (rather than lock) allows awaiting; note it is NOT reentrant.
          private readonly SemaphoreSlim _semaphore = new SemaphoreSlim(1, 1);

          // Topic filter -> QoS level requested in the SUBSCRIBE packet.
          // Subscribing again to an existing filter overwrites its QoS.
          private readonly Dictionary<string, MqttQualityOfServiceLevel> _subscriptions = new Dictionary<string, MqttQualityOfServiceLevel>();

          private readonly IMqttServerOptions _options;

          /// <summary>Creates a manager that consults <paramref name="options"/> for the subscription interceptor.</summary>
          /// <param name="options">Server options; must not be null.</param>
          /// <exception cref="ArgumentNullException"><paramref name="options"/> is null.</exception>
          public MqttClientSubscriptionsManager(IMqttServerOptions options)
          {
              _options = options ?? throw new ArgumentNullException(nameof(options));
          }

          /// <summary>
          /// Processes a SUBSCRIBE packet. The configured subscription interceptor runs for every
          /// topic filter. Accepted filters are recorded, and the matching SUBACK is built with one
          /// return code per filter, in packet order.
          /// </summary>
          /// <param name="subscribePacket">The SUBSCRIBE packet to process; must not be null.</param>
          /// <param name="clientId">Client identifier handed to the interceptor context.</param>
          /// <returns>
          /// The SUBACK to send back, and whether any interceptor invocation requested the
          /// connection to be closed.
          /// </returns>
          /// <exception cref="ArgumentNullException"><paramref name="subscribePacket"/> is null.</exception>
          public async Task<MqttClientSubscribeResult> SubscribeAsync(MqttSubscribePacket subscribePacket, string clientId)
          {
              if (subscribePacket == null) throw new ArgumentNullException(nameof(subscribePacket));

              var result = new MqttClientSubscribeResult
              {
                  ResponsePacket = subscribePacket.CreateResponse<MqttSubAckPacket>(),
                  CloseConnection = false
              };

              await _semaphore.WaitAsync().ConfigureAwait(false);
              try
              {
                  foreach (var topicFilter in subscribePacket.TopicFilters)
                  {
                      // NOTE: the interceptor runs while _semaphore is held. Because SemaphoreSlim is not
                      // reentrant, an interceptor that waits on a call back into this manager would deadlock.
                      var interceptorContext = InterceptSubscribe(clientId, topicFilter);

                      var returnCode = interceptorContext.AcceptSubscription
                          ? ConvertToMaximumQoS(topicFilter.QualityOfServiceLevel)
                          : MqttSubscribeReturnCode.Failure;

                      // SUBACK return codes are positional: exactly one per topic filter, in packet order.
                      result.ResponsePacket.SubscribeReturnCodes.Add(returnCode);

                      // Record only what the SUBACK reports as granted, so the table never disagrees
                      // with what the client was told (e.g. an out-of-range QoS maps to Failure).
                      if (returnCode != MqttSubscribeReturnCode.Failure)
                      {
                          _subscriptions[topicFilter.Topic] = topicFilter.QualityOfServiceLevel;
                      }

                      if (interceptorContext.CloseConnection)
                      {
                          result.CloseConnection = true;
                      }
                  }
              }
              finally
              {
                  _semaphore.Release();
              }

              return result;
          }

          /// <summary>
          /// Processes an UNSUBSCRIBE packet by removing each listed topic filter. Filters that are
          /// not currently subscribed are ignored.
          /// </summary>
          /// <param name="unsubscribePacket">The UNSUBSCRIBE packet to process; must not be null.</param>
          /// <returns>The UNSUBACK created from <paramref name="unsubscribePacket"/>.</returns>
          /// <exception cref="ArgumentNullException"><paramref name="unsubscribePacket"/> is null.</exception>
          public async Task<MqttUnsubAckPacket> UnsubscribeAsync(MqttUnsubscribePacket unsubscribePacket)
          {
              if (unsubscribePacket == null) throw new ArgumentNullException(nameof(unsubscribePacket));

              await _semaphore.WaitAsync().ConfigureAwait(false);
              try
              {
                  foreach (var topicFilter in unsubscribePacket.TopicFilters)
                  {
                      // Dictionary.Remove simply returns false for an unknown key.
                      _subscriptions.Remove(topicFilter);
                  }
              }
              finally
              {
                  _semaphore.Release();
              }

              return unsubscribePacket.CreateResponse<MqttUnsubAckPacket>();
          }

          /// <summary>
          /// Determines whether <paramref name="applicationMessage"/> matches any stored topic filter.
          /// If it does, also determines the QoS level at which it should be delivered.
          /// </summary>
          /// <param name="applicationMessage">The message to test; must not be null.</param>
          /// <returns>
          /// <c>IsSubscribed = false</c> when no filter matches. Otherwise <c>IsSubscribed = true</c>
          /// and the effective delivery QoS.
          /// </returns>
          /// <exception cref="ArgumentNullException"><paramref name="applicationMessage"/> is null.</exception>
          public async Task<CheckSubscriptionsResult> CheckSubscriptionsAsync(MqttApplicationMessage applicationMessage)
          {
              if (applicationMessage == null) throw new ArgumentNullException(nameof(applicationMessage));

              await _semaphore.WaitAsync().ConfigureAwait(false);
              try
              {
                  // Highest QoS across all matching (possibly overlapping) filters; null means "no match".
                  MqttQualityOfServiceLevel? maximumGrantedQoS = null;
                  foreach (var subscription in _subscriptions)
                  {
                      if (!MqttTopicFilterComparer.IsMatch(applicationMessage.Topic, subscription.Key))
                      {
                          continue;
                      }

                      if (maximumGrantedQoS == null || subscription.Value > maximumGrantedQoS.Value)
                      {
                          maximumGrantedQoS = subscription.Value;
                      }
                  }

                  if (maximumGrantedQoS == null)
                  {
                      return new CheckSubscriptionsResult
                      {
                          IsSubscribed = false
                      };
                  }

                  return CreateSubscriptionResult(applicationMessage, maximumGrantedQoS.Value);
              }
              finally
              {
                  _semaphore.Release();
              }
          }

          /// <summary>
          /// Runs the optional subscription interceptor from the server options for one topic filter.
          /// </summary>
          /// <returns>
          /// The interceptor context. The interceptor may have modified it; if no interceptor is
          /// configured, it is returned as constructed.
          /// </returns>
          private MqttSubscriptionInterceptorContext InterceptSubscribe(string clientId, TopicFilter topicFilter)
          {
              var interceptorContext = new MqttSubscriptionInterceptorContext(clientId, topicFilter);
              _options.SubscriptionInterceptor?.Invoke(interceptorContext);
              return interceptorContext;
          }

          /// <summary>
          /// Builds a positive subscription result with the effective delivery QoS.
          /// </summary>
          /// <param name="applicationMessage">The matched message.</param>
          /// <param name="maximumGrantedQoS">Highest QoS among all filters that matched the message.</param>
          private static CheckSubscriptionsResult CreateSubscriptionResult(MqttApplicationMessage applicationMessage, MqttQualityOfServiceLevel maximumGrantedQoS)
          {
              // MQTT 3.1.1: overlapping subscriptions are honored at their maximum QoS, and the delivered
              // QoS must be the minimum of the published QoS and that granted QoS. A message is never
              // upgraded (e.g. a QoS 0 publish is not delivered at QoS 2). Relies on the enum ordering
              // AtMostOnce < AtLeastOnce < ExactlyOnce, which the previous Max()-based code also assumed.
              var publishedQoS = applicationMessage.QualityOfServiceLevel;
              var effectiveQoS = publishedQoS < maximumGrantedQoS ? publishedQoS : maximumGrantedQoS;

              return new CheckSubscriptionsResult
              {
                  IsSubscribed = true,
                  QualityOfServiceLevel = effectiveQoS
              };
          }

          /// <summary>
          /// Maps a requested QoS level to the corresponding SUBACK "success, maximum QoS n" return
          /// code. Any unrecognized value yields <see cref="MqttSubscribeReturnCode.Failure"/>.
          /// </summary>
          private static MqttSubscribeReturnCode ConvertToMaximumQoS(MqttQualityOfServiceLevel qualityOfServiceLevel)
          {
              switch (qualityOfServiceLevel)
              {
                  case MqttQualityOfServiceLevel.AtMostOnce: return MqttSubscribeReturnCode.SuccessMaximumQoS0;
                  case MqttQualityOfServiceLevel.AtLeastOnce: return MqttSubscribeReturnCode.SuccessMaximumQoS1;
                  case MqttQualityOfServiceLevel.ExactlyOnce: return MqttSubscribeReturnCode.SuccessMaximumQoS2;
                  default: return MqttSubscribeReturnCode.Failure;
              }
          }
      }
  }