You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 

181 lines
6.5 KiB

  1. using System;
  2. using System.Collections.Generic;
  3. using System.Linq;
  4. using System.Threading;
  5. using System.Threading.Tasks;
  6. using MQTTnet.Packets;
  7. using MQTTnet.Protocol;
  8. namespace MQTTnet.Server
  9. {
  /// <summary>
  /// Keeps track of the topic filters a single client is subscribed to, together with
  /// the quality of service level requested for each filter.
  /// </summary>
  /// <remarks>
  /// Every access to the subscription table is serialized through an internal
  /// <see cref="SemaphoreSlim"/>. The configured subscription interceptor as well as
  /// <see cref="TopicSubscribedCallback"/> and <see cref="TopicUnsubscribedCallback"/>
  /// are invoked while that semaphore is held. Because the semaphore is not reentrant,
  /// those delegates must not call back into the same instance.
  /// </remarks>
  public sealed class MqttClientSubscriptionsManager : IDisposable
  {
      // Guards _subscriptions; a SemaphoreSlim is used so the lock can be awaited.
      private readonly SemaphoreSlim _semaphore = new SemaphoreSlim(1, 1);

      // Topic filter -> quality of service level requested for that filter.
      private readonly Dictionary<string, MqttQualityOfServiceLevel> _subscriptions = new Dictionary<string, MqttQualityOfServiceLevel>();

      private readonly IMqttServerOptions _options;
      private readonly string _clientId;

      /// <summary>
      /// Creates a subscription manager for one client.
      /// </summary>
      /// <param name="options">Server options; its subscription interceptor (if any) is consulted on every subscribe.</param>
      /// <param name="clientId">Identifier of the client; passed to the interceptor context and to the callbacks.</param>
      /// <exception cref="ArgumentNullException"><paramref name="options"/> or <paramref name="clientId"/> is <c>null</c>.</exception>
      public MqttClientSubscriptionsManager(IMqttServerOptions options, string clientId)
      {
          _options = options ?? throw new ArgumentNullException(nameof(options));
          _clientId = clientId ?? throw new ArgumentNullException(nameof(clientId));
      }

      /// <summary>
      /// Invoked with the client id and the topic filter each time a subscription has been stored.
      /// </summary>
      public Action<string, TopicFilter> TopicSubscribedCallback { get; set; }

      /// <summary>
      /// Invoked with the client id and the topic filter string for every filter contained in an
      /// unsubscribe packet, whether or not that filter was subscribed before.
      /// </summary>
      public Action<string, string> TopicUnsubscribedCallback { get; set; }

      /// <summary>
      /// Processes all topic filters of a subscribe packet and builds the matching acknowledgement.
      /// </summary>
      /// <param name="subscribePacket">The subscribe packet received from the client.</param>
      /// <returns>
      /// A result whose response packet carries the packet identifier of
      /// <paramref name="subscribePacket"/> and one return code per topic filter, in order.
      /// <c>CloseConnection</c> is set when the interceptor requested it for any filter.
      /// </returns>
      /// <exception cref="ArgumentNullException"><paramref name="subscribePacket"/> is <c>null</c>.</exception>
      public async Task<MqttClientSubscribeResult> SubscribeAsync(MqttSubscribePacket subscribePacket)
      {
          if (subscribePacket == null) throw new ArgumentNullException(nameof(subscribePacket));

          var result = new MqttClientSubscribeResult
          {
              ResponsePacket = new MqttSubAckPacket
              {
                  PacketIdentifier = subscribePacket.PacketIdentifier
              },
              CloseConnection = false
          };

          await _semaphore.WaitAsync().ConfigureAwait(false);
          try
          {
              foreach (var topicFilter in subscribePacket.TopicFilters)
              {
                  var interceptorContext = InterceptSubscribe(topicFilter);

                  // A rejected subscription is always reported as a failure; an accepted one
                  // may still fail when the requested QoS level is not a known value.
                  var returnCode = interceptorContext.AcceptSubscription
                      ? ConvertToMaximumQoS(topicFilter.QualityOfServiceLevel)
                      : MqttSubscribeReturnCode.Failure;

                  result.ResponsePacket.SubscribeReturnCodes.Add(returnCode);

                  if (interceptorContext.CloseConnection)
                  {
                      result.CloseConnection = true;
                  }

                  // Only store the subscription when the client is told that it succeeded,
                  // so the stored state never contradicts the acknowledgement.
                  if (returnCode == MqttSubscribeReturnCode.Failure)
                  {
                      continue;
                  }

                  _subscriptions[topicFilter.Topic] = topicFilter.QualityOfServiceLevel;
                  TopicSubscribedCallback?.Invoke(_clientId, topicFilter);
              }
          }
          finally
          {
              _semaphore.Release();
          }

          return result;
      }

      /// <summary>
      /// Removes all topic filters listed in an unsubscribe packet.
      /// </summary>
      /// <param name="unsubscribePacket">The unsubscribe packet received from the client.</param>
      /// <returns>An acknowledgement carrying the packet identifier of <paramref name="unsubscribePacket"/>.</returns>
      /// <exception cref="ArgumentNullException"><paramref name="unsubscribePacket"/> is <c>null</c>.</exception>
      public async Task<MqttUnsubAckPacket> UnsubscribeAsync(MqttUnsubscribePacket unsubscribePacket)
      {
          if (unsubscribePacket == null) throw new ArgumentNullException(nameof(unsubscribePacket));

          await _semaphore.WaitAsync().ConfigureAwait(false);
          try
          {
              foreach (var topicFilter in unsubscribePacket.TopicFilters)
              {
                  // The callback fires for every requested filter, even if it was not stored.
                  _subscriptions.Remove(topicFilter);
                  TopicUnsubscribedCallback?.Invoke(_clientId, topicFilter);
              }
          }
          finally
          {
              _semaphore.Release();
          }

          return new MqttUnsubAckPacket
          {
              PacketIdentifier = unsubscribePacket.PacketIdentifier
          };
      }

      /// <summary>
      /// Determines whether any stored topic filter matches the topic of a message and, if so,
      /// which quality of service level the message should be delivered with.
      /// </summary>
      /// <param name="applicationMessage">The message about to be delivered.</param>
      /// <returns>
      /// A result with <c>IsSubscribed</c> set to <c>false</c> when no stored filter matches;
      /// otherwise <c>IsSubscribed</c> is <c>true</c> and <c>QualityOfServiceLevel</c> holds the
      /// effective level.
      /// </returns>
      /// <exception cref="ArgumentNullException"><paramref name="applicationMessage"/> is <c>null</c>.</exception>
      public async Task<CheckSubscriptionsResult> CheckSubscriptionsAsync(MqttApplicationMessage applicationMessage)
      {
          if (applicationMessage == null) throw new ArgumentNullException(nameof(applicationMessage));

          await _semaphore.WaitAsync().ConfigureAwait(false);
          try
          {
              // Collect the distinct QoS levels of all filters that match the message topic.
              var qosLevels = new HashSet<MqttQualityOfServiceLevel>();
              foreach (var subscription in _subscriptions)
              {
                  if (MqttTopicFilterComparer.IsMatch(applicationMessage.Topic, subscription.Key))
                  {
                      qosLevels.Add(subscription.Value);
                  }
              }

              if (qosLevels.Count == 0)
              {
                  return new CheckSubscriptionsResult
                  {
                      IsSubscribed = false
                  };
              }

              return CreateSubscriptionResult(applicationMessage, qosLevels);
          }
          finally
          {
              _semaphore.Release();
          }
      }

      /// <summary>
      /// Disposes the internal semaphore.
      /// </summary>
      public void Dispose()
      {
          _semaphore.Dispose();
      }

      /// <summary>
      /// Maps a requested quality of service level to the corresponding "success" return code.
      /// </summary>
      /// <param name="qualityOfServiceLevel">The level requested by the client.</param>
      /// <returns>
      /// The matching success code, or <see cref="MqttSubscribeReturnCode.Failure"/> for any
      /// value that is not one of the three known levels.
      /// </returns>
      private static MqttSubscribeReturnCode ConvertToMaximumQoS(MqttQualityOfServiceLevel qualityOfServiceLevel)
      {
          switch (qualityOfServiceLevel)
          {
              case MqttQualityOfServiceLevel.AtMostOnce: return MqttSubscribeReturnCode.SuccessMaximumQoS0;
              case MqttQualityOfServiceLevel.AtLeastOnce: return MqttSubscribeReturnCode.SuccessMaximumQoS1;
              case MqttQualityOfServiceLevel.ExactlyOnce: return MqttSubscribeReturnCode.SuccessMaximumQoS2;
              default: return MqttSubscribeReturnCode.Failure;
          }
      }

      /// <summary>
      /// Builds an interceptor context for a topic filter and passes it to the configured
      /// subscription interceptor, if there is one.
      /// </summary>
      /// <param name="topicFilter">The topic filter the client wants to subscribe to.</param>
      /// <returns>The context after the interceptor (if any) has run.</returns>
      private MqttSubscriptionInterceptorContext InterceptSubscribe(TopicFilter topicFilter)
      {
          var interceptorContext = new MqttSubscriptionInterceptorContext(_clientId, topicFilter);
          _options.SubscriptionInterceptor?.Invoke(interceptorContext);
          return interceptorContext;
      }

      /// <summary>
      /// Computes the effective delivery QoS for a message that matched at least one subscription.
      /// </summary>
      /// <param name="applicationMessage">The message about to be delivered.</param>
      /// <param name="subscribedQoSLevels">
      /// The distinct QoS levels of all matching subscriptions; must not be empty.
      /// </param>
      /// <returns>A result with <c>IsSubscribed</c> set and the effective QoS level.</returns>
      private static CheckSubscriptionsResult CreateSubscriptionResult(MqttApplicationMessage applicationMessage, HashSet<MqttQualityOfServiceLevel> subscribedQoSLevels)
      {
          // MQTT 3.1.1 [MQTT-3.8.4-6]: the delivery QoS is the minimum of the QoS the message
          // was published with and the maximum QoS of the matching subscriptions. A message
          // must never be upgraded above the level it was published with.
          var publishedQoS = applicationMessage.QualityOfServiceLevel;
          var maximumSubscribedQoS = subscribedQoSLevels.Max();
          var effectiveQoS = publishedQoS < maximumSubscribedQoS ? publishedQoS : maximumSubscribedQoS;

          return new CheckSubscriptionsResult
          {
              IsSubscribed = true,
              QualityOfServiceLevel = effectiveQoS
          };
      }
  }
  157. }