You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 

315 lines
12 KiB

  1. using System;
  2. using System.Collections.Generic;
  3. using System.Linq;
  4. using System.Threading;
  5. using System.Threading.Tasks;
  6. using MQTTnet.Adapter;
  7. using MQTTnet.Diagnostics;
  8. using MQTTnet.Exceptions;
  9. using MQTTnet.Packets;
  10. using MQTTnet.Protocol;
  11. using MQTTnet.Serializer;
  12. namespace MQTTnet.Server
  13. {
  /// <summary>
  /// Owns the server-side <see cref="MqttClientSession"/> instances, keyed by client ID, and
  /// coordinates client connections, subscription changes and application-message fan-out.
  /// </summary>
  /// <remarks>
  /// Every read or write of the session dictionary below happens while holding <c>_semaphore</c>.
  /// </remarks>
  public sealed class MqttClientSessionsManager : IDisposable
  {
      private readonly Dictionary<string, MqttClientSession> _sessions = new Dictionary<string, MqttClientSession>();

      // Async-compatible mutex guarding _sessions; a C# lock cannot be held across an await.
      private readonly SemaphoreSlim _semaphore = new SemaphoreSlim(1, 1);
      private readonly IMqttServerOptions _options;
      private readonly MqttRetainedMessagesManager _retainedMessagesManager;
      private readonly IMqttNetLogger _logger;

      /// <summary>Creates a manager that initially holds no sessions.</summary>
      /// <param name="options">Server options (communication timeout, connection validator, message interceptor).</param>
      /// <param name="retainedMessagesManager">Receives every dispatched message that has the retain flag set.</param>
      /// <param name="logger">Logger for errors and session trace output.</param>
      /// <exception cref="ArgumentNullException">Any argument is <c>null</c>.</exception>
      public MqttClientSessionsManager(IMqttServerOptions options, MqttRetainedMessagesManager retainedMessagesManager, IMqttNetLogger logger)
      {
          _logger = logger ?? throw new ArgumentNullException(nameof(logger));
          _options = options ?? throw new ArgumentNullException(nameof(options));
          _retainedMessagesManager = retainedMessagesManager ?? throw new ArgumentNullException(nameof(retainedMessagesManager));
      }

      /// <summary>Invoked after a client's CONNECT was accepted and the CONNACK was sent.</summary>
      public Action<ConnectedMqttClient> ClientConnectedCallback { get; set; }

      /// <summary>
      /// Invoked whenever <see cref="RunSessionAsync"/> finishes. This includes connections that were
      /// rejected or never sent a CONNECT packet; the client ID is then empty or the rejected ID.
      /// </summary>
      public Action<ConnectedMqttClient> ClientDisconnectedCallback { get; set; }

      /// <summary>
      /// Copied into each newly created session's subscriptions manager. The value is captured at
      /// creation time, so later changes do not reach existing sessions.
      /// </summary>
      public Action<string, TopicFilter> ClientSubscribedTopicCallback { get; set; }

      /// <summary>
      /// Copied into each newly created session's subscriptions manager. The value is captured at
      /// creation time, so later changes do not reach existing sessions.
      /// </summary>
      public Action<string, string> ClientUnsubscribedTopicCallback { get; set; }

      /// <summary>
      /// Invoked with the sender's client ID (<c>null</c> when there is no sender session) and the
      /// possibly intercepted message, before it is queued for the sessions.
      /// </summary>
      public Action<string, MqttApplicationMessage> ApplicationMessageReceivedCallback { get; set; }

      /// <summary>
      /// Runs one client connection from start to finish: reads and validates the CONNECT packet,
      /// answers with a CONNACK, then hands the adapter to the client's session until it returns.
      /// </summary>
      /// <param name="clientAdapter">Channel adapter of the newly accepted connection.</param>
      /// <param name="cancellationToken">Used for the CONNECT receive and the CONNACK send.</param>
      /// <remarks>
      /// Exceptions are logged, not rethrown. On exit the adapter is always disconnected and
      /// <see cref="ClientDisconnectedCallback"/> is always invoked.
      /// </remarks>
      public async Task RunSessionAsync(IMqttChannelAdapter clientAdapter, CancellationToken cancellationToken)
      {
          var clientId = string.Empty;
          MqttClientSession clientSession = null;

          try
          {
              if (!(await clientAdapter.ReceivePacketAsync(_options.DefaultCommunicationTimeout, cancellationToken).ConfigureAwait(false) is MqttConnectPacket connectPacket))
              {
                  throw new MqttProtocolViolationException("The first packet from a client must be a 'CONNECT' packet [MQTT-3.1.0-1].");
              }

              clientId = connectPacket.ClientId;

              // Switch to the required protocol version before sending any response.
              clientAdapter.PacketSerializer.ProtocolVersion = connectPacket.ProtocolVersion;

              var connectReturnCode = ValidateConnection(connectPacket);
              if (connectReturnCode != MqttConnectReturnCode.ConnectionAccepted)
              {
                  // Rejected: tell the client why; the finally block closes the connection.
                  await clientAdapter.SendPacketsAsync(_options.DefaultCommunicationTimeout, cancellationToken, new MqttConnAckPacket
                  {
                      ConnectReturnCode = connectReturnCode
                  }).ConfigureAwait(false);

                  return;
              }

              var result = await GetOrCreateClientSessionAsync(connectPacket).ConfigureAwait(false);
              clientSession = result.Session;

              await clientAdapter.SendPacketsAsync(_options.DefaultCommunicationTimeout, cancellationToken, new MqttConnAckPacket
              {
                  ConnectReturnCode = connectReturnCode,
                  IsSessionPresent = result.IsExistingSession
              }).ConfigureAwait(false);

              ClientConnectedCallback?.Invoke(new ConnectedMqttClient
              {
                  ClientId = clientId,
                  ProtocolVersion = clientAdapter.PacketSerializer.ProtocolVersion
              });

              await clientSession.RunAsync(connectPacket, clientAdapter).ConfigureAwait(false);
          }
          catch (Exception exception)
          {
              _logger.Error<MqttClientSessionsManager>(exception, exception.Message);
          }
          finally
          {
              try
              {
                  await clientAdapter.DisconnectAsync(_options.DefaultCommunicationTimeout).ConfigureAwait(false);
              }
              catch (Exception exception)
              {
                  // Best effort: the peer may already be gone. Record it rather than failing the cleanup.
                  _logger.Trace<MqttClientSessionsManager>("Error while disconnecting client '{0}': {1}", clientId, exception.Message);
              }

              ClientDisconnectedCallback?.Invoke(new ConnectedMqttClient
              {
                  ClientId = clientId,
                  ProtocolVersion = clientAdapter.PacketSerializer.ProtocolVersion,
                  PendingApplicationMessages = clientSession?.PendingMessagesQueue.Count ?? 0
              });
          }
      }

      /// <summary>
      /// Stops and disposes every session, then empties the session table.
      /// </summary>
      public async Task StopAsync()
      {
          await _semaphore.WaitAsync().ConfigureAwait(false);
          try
          {
              foreach (var session in _sessions.Values)
              {
                  await session.StopAsync().ConfigureAwait(false);

                  // Same stop-then-dispose sequence used when a clean session replaces an old one;
                  // clearing the table without disposing would leak the sessions.
                  session.Dispose();
              }

              _sessions.Clear();
          }
          finally
          {
              _semaphore.Release();
          }
      }

      /// <summary>
      /// Returns a snapshot describing every session whose <c>IsConnected</c> is true.
      /// </summary>
      /// <returns>
      /// A new list. A session without a known protocol version is reported as
      /// <see cref="MqttProtocolVersion.V311"/>.
      /// </returns>
      public async Task<IList<ConnectedMqttClient>> GetConnectedClientsAsync()
      {
          await _semaphore.WaitAsync().ConfigureAwait(false);
          try
          {
              return _sessions.Values
                  .Where(session => session.IsConnected)
                  .Select(session => new ConnectedMqttClient
                  {
                      ClientId = session.ClientId,
                      ProtocolVersion = session.ProtocolVersion ?? MqttProtocolVersion.V311,
                      LastPacketReceived = session.KeepAliveMonitor.LastPacketReceived,
                      LastNonKeepAlivePacketReceived = session.KeepAliveMonitor.LastNonKeepAlivePacketReceived,
                      PendingApplicationMessages = session.PendingMessagesQueue.Count
                  })
                  .ToList();
          }
          finally
          {
              _semaphore.Release();
          }
      }

      /// <summary>
      /// Processes one published message and enqueues it on every known session. Processing means
      /// applying the interceptor, storing it if retained, and raising
      /// <see cref="ApplicationMessageReceivedCallback"/>.
      /// </summary>
      /// <param name="senderClientSession">Publishing session, or <c>null</c> when there is no sender session.</param>
      /// <param name="applicationMessage">The message to dispatch.</param>
      public async Task DispatchApplicationMessageAsync(MqttClientSession senderClientSession, MqttApplicationMessage applicationMessage)
      {
          try
          {
              applicationMessage = InterceptApplicationMessage(senderClientSession, applicationMessage);
              if (applicationMessage == null)
              {
                  // The interceptor suppressed the message (or none was supplied): nothing to deliver.
                  return;
              }

              if (applicationMessage.Retain)
              {
                  await _retainedMessagesManager.HandleMessageAsync(senderClientSession?.ClientId, applicationMessage).ConfigureAwait(false);
              }

              ApplicationMessageReceivedCallback?.Invoke(senderClientSession?.ClientId, applicationMessage);
          }
          catch (Exception exception)
          {
              _logger.Error<MqttClientSessionsManager>(exception, "Error while processing application message");
          }

          // Failures above are logged but do not stop delivery to the sessions.
          await _semaphore.WaitAsync().ConfigureAwait(false);
          try
          {
              foreach (var clientSession in _sessions.Values)
              {
                  // NOTE(review): presumably each session filters by its own subscriptions — not visible here.
                  await clientSession.EnqueueApplicationMessageAsync(applicationMessage).ConfigureAwait(false);
              }
          }
          finally
          {
              _semaphore.Release();
          }
      }

      /// <summary>Adds topic subscriptions to an existing client session.</summary>
      /// <exception cref="ArgumentNullException"><paramref name="clientId"/> or <paramref name="topicFilters"/> is <c>null</c>.</exception>
      /// <exception cref="InvalidOperationException">No session exists for <paramref name="clientId"/>.</exception>
      public async Task SubscribeAsync(string clientId, IList<TopicFilter> topicFilters)
      {
          if (clientId == null) throw new ArgumentNullException(nameof(clientId));
          if (topicFilters == null) throw new ArgumentNullException(nameof(topicFilters));

          await _semaphore.WaitAsync().ConfigureAwait(false);
          try
          {
              if (!_sessions.TryGetValue(clientId, out var session))
              {
                  throw new InvalidOperationException($"Client session {clientId} is unknown.");
              }

              await session.SubscribeAsync(topicFilters).ConfigureAwait(false);
          }
          finally
          {
              _semaphore.Release();
          }
      }

      /// <summary>Removes topic subscriptions from an existing client session.</summary>
      /// <exception cref="ArgumentNullException"><paramref name="clientId"/> or <paramref name="topicFilters"/> is <c>null</c>.</exception>
      /// <exception cref="InvalidOperationException">No session exists for <paramref name="clientId"/>.</exception>
      public async Task UnsubscribeAsync(string clientId, IList<string> topicFilters)
      {
          if (clientId == null) throw new ArgumentNullException(nameof(clientId));
          if (topicFilters == null) throw new ArgumentNullException(nameof(topicFilters));

          await _semaphore.WaitAsync().ConfigureAwait(false);
          try
          {
              if (!_sessions.TryGetValue(clientId, out var session))
              {
                  throw new InvalidOperationException($"Client session {clientId} is unknown.");
              }

              await session.UnsubscribeAsync(topicFilters).ConfigureAwait(false);
          }
          finally
          {
              _semaphore.Release();
          }
      }

      // Runs the configured interceptor, if any. Its returned message (possibly replaced or null)
      // is what gets dispatched. The sender may be null, matching DispatchApplicationMessageAsync;
      // dereferencing it unconditionally made sender-less messages throw, and the caller then
      // delivered them without interception.
      private MqttApplicationMessage InterceptApplicationMessage(MqttClientSession senderClientSession, MqttApplicationMessage applicationMessage)
      {
          var interceptor = _options.ApplicationMessageInterceptor;
          if (interceptor == null)
          {
              return applicationMessage;
          }

          var interceptorContext = new MqttApplicationMessageInterceptorContext(
              senderClientSession?.ClientId,
              applicationMessage);

          interceptor(interceptorContext);
          return interceptorContext.ApplicationMessage;
      }

      // Accepts every connection when no validator is configured; otherwise the validator decides
      // through the context's ReturnCode.
      private MqttConnectReturnCode ValidateConnection(MqttConnectPacket connectPacket)
      {
          if (_options.ConnectionValidator == null)
          {
              return MqttConnectReturnCode.ConnectionAccepted;
          }

          var context = new MqttConnectionValidatorContext(
              connectPacket.ClientId,
              connectPacket.Username,
              connectPacket.Password,
              connectPacket.WillMessage);

          _options.ConnectionValidator(context);
          return context.ReturnCode;
      }

      // Finds the session for the CONNECT packet's client ID, or creates one.
      // With CleanSession set, an existing session is removed, stopped and disposed, and a fresh
      // one replaces it. Otherwise an existing session is reused. IsExistingSession feeds the
      // CONNACK "session present" flag.
      private async Task<GetOrCreateClientSessionResult> GetOrCreateClientSessionAsync(MqttConnectPacket connectPacket)
      {
          await _semaphore.WaitAsync().ConfigureAwait(false);
          try
          {
              var isSessionPresent = _sessions.TryGetValue(connectPacket.ClientId, out var clientSession);
              if (isSessionPresent)
              {
                  if (connectPacket.CleanSession)
                  {
                      _sessions.Remove(connectPacket.ClientId);
                      await clientSession.StopAsync().ConfigureAwait(false);
                      clientSession.Dispose();
                      clientSession = null;

                      _logger.Trace<MqttClientSessionsManager>("Stopped existing session of client '{0}'.", connectPacket.ClientId);
                  }
                  else
                  {
                      _logger.Trace<MqttClientSessionsManager>("Reusing existing session of client '{0}'.", connectPacket.ClientId);
                  }
              }

              var isExistingSession = true;
              if (clientSession == null)
              {
                  isExistingSession = false;

                  clientSession = new MqttClientSession(connectPacket.ClientId, _options, _retainedMessagesManager, _logger)
                  {
                      ApplicationMessageReceivedCallback = DispatchApplicationMessageAsync
                  };

                  // The current callback values are copied into the new session.
                  clientSession.SubscriptionsManager.TopicSubscribedCallback = ClientSubscribedTopicCallback;
                  clientSession.SubscriptionsManager.TopicUnsubscribedCallback = ClientUnsubscribedTopicCallback;

                  _sessions[connectPacket.ClientId] = clientSession;

                  _logger.Trace<MqttClientSessionsManager>("Created a new session for client '{0}'.", connectPacket.ClientId);
              }

              return new GetOrCreateClientSessionResult { IsExistingSession = isExistingSession, Session = clientSession };
          }
          finally
          {
              _semaphore.Release();
          }
      }

      /// <summary>
      /// Detaches all callbacks and disposes the semaphore. Sessions are not stopped or disposed
      /// here; call <see cref="StopAsync"/> first for that.
      /// </summary>
      public void Dispose()
      {
          ClientConnectedCallback = null;
          ClientDisconnectedCallback = null;
          ClientSubscribedTopicCallback = null;
          ClientUnsubscribedTopicCallback = null;
          ApplicationMessageReceivedCallback = null;

          _semaphore?.Dispose();
      }
  }
  271. }