You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 

395 lines
15 KiB

  1. using System;
  2. using System.Collections.Concurrent;
  3. using System.Collections.Generic;
  4. using System.Threading;
  5. using System.Threading.Tasks;
  6. using MQTTnet.Adapter;
  7. using MQTTnet.Diagnostics;
  8. using MQTTnet.Internal;
  9. using MQTTnet.Packets;
  10. using MQTTnet.Protocol;
  11. using MQTTnet.Server.Status;
  12. namespace MQTTnet.Server
  13. {
  14. public class MqttClientSessionsManager : IDisposable
  15. {
  16. private readonly AsyncQueue<MqttEnqueuedApplicationMessage> _messageQueue = new AsyncQueue<MqttEnqueuedApplicationMessage>();
  17. private readonly SemaphoreSlim _createConnectionGate = new SemaphoreSlim(1, 1);
  18. private readonly ConcurrentDictionary<string, MqttClientConnection> _connections = new ConcurrentDictionary<string, MqttClientConnection>();
  19. private readonly ConcurrentDictionary<string, MqttClientSession> _sessions = new ConcurrentDictionary<string, MqttClientSession>();
  20. private readonly CancellationToken _cancellationToken;
  21. private readonly MqttServerEventDispatcher _eventDispatcher;
  22. private readonly MqttRetainedMessagesManager _retainedMessagesManager;
  23. private readonly IMqttServerOptions _options;
  24. private readonly IMqttNetChildLogger _logger;
  25. public MqttClientSessionsManager(
  26. IMqttServerOptions options,
  27. MqttRetainedMessagesManager retainedMessagesManager,
  28. CancellationToken cancellationToken,
  29. MqttServerEventDispatcher eventDispatcher,
  30. IMqttNetChildLogger logger)
  31. {
  32. _cancellationToken = cancellationToken;
  33. if (logger == null) throw new ArgumentNullException(nameof(logger));
  34. _logger = logger.CreateChildLogger(nameof(MqttClientSessionsManager));
  35. _eventDispatcher = eventDispatcher ?? throw new ArgumentNullException(nameof(eventDispatcher));
  36. _options = options ?? throw new ArgumentNullException(nameof(options));
  37. _retainedMessagesManager = retainedMessagesManager ?? throw new ArgumentNullException(nameof(retainedMessagesManager));
  38. }
  39. public void Start()
  40. {
  41. Task.Run(() => TryProcessQueuedApplicationMessagesAsync(_cancellationToken), _cancellationToken);
  42. }
  43. public async Task StopAsync()
  44. {
  45. foreach (var connection in _connections.Values)
  46. {
  47. await connection.StopAsync().ConfigureAwait(false);
  48. }
  49. }
  50. public Task HandleConnectionAsync(IMqttChannelAdapter clientAdapter)
  51. {
  52. return HandleConnectionAsync(clientAdapter, _cancellationToken);
  53. }
  54. public Task<IList<IMqttClientStatus>> GetClientStatusAsync()
  55. {
  56. var result = new List<IMqttClientStatus>();
  57. foreach (var connection in _connections.Values)
  58. {
  59. var clientStatus = new MqttClientStatus(connection);
  60. connection.FillStatus(clientStatus);
  61. var sessionStatus = new MqttSessionStatus(connection.Session, this);
  62. connection.Session.FillStatus(sessionStatus);
  63. clientStatus.Session = sessionStatus;
  64. result.Add(clientStatus);
  65. }
  66. return Task.FromResult((IList<IMqttClientStatus>)result);
  67. }
  68. public Task<IList<IMqttSessionStatus>> GetSessionStatusAsync()
  69. {
  70. var result = new List<IMqttSessionStatus>();
  71. foreach (var session in _sessions.Values)
  72. {
  73. var sessionStatus = new MqttSessionStatus(session, this);
  74. session.FillStatus(sessionStatus);
  75. result.Add(sessionStatus);
  76. }
  77. return Task.FromResult((IList<IMqttSessionStatus>)result);
  78. }
  79. public void DispatchApplicationMessage(MqttApplicationMessage applicationMessage, MqttClientConnection sender)
  80. {
  81. if (applicationMessage == null) throw new ArgumentNullException(nameof(applicationMessage));
  82. _messageQueue.Enqueue(new MqttEnqueuedApplicationMessage(applicationMessage, sender));
  83. }
  84. public Task SubscribeAsync(string clientId, ICollection<TopicFilter> topicFilters)
  85. {
  86. if (clientId == null) throw new ArgumentNullException(nameof(clientId));
  87. if (topicFilters == null) throw new ArgumentNullException(nameof(topicFilters));
  88. if (!_sessions.TryGetValue(clientId, out var session))
  89. {
  90. throw new InvalidOperationException($"Client session '{clientId}' is unknown.");
  91. }
  92. return session.SubscribeAsync(topicFilters, _retainedMessagesManager);
  93. }
  94. public Task UnsubscribeAsync(string clientId, IEnumerable<string> topicFilters)
  95. {
  96. if (clientId == null) throw new ArgumentNullException(nameof(clientId));
  97. if (topicFilters == null) throw new ArgumentNullException(nameof(topicFilters));
  98. if (!_sessions.TryGetValue(clientId, out var session))
  99. {
  100. throw new InvalidOperationException($"Client session '{clientId}' is unknown.");
  101. }
  102. return session.UnsubscribeAsync(topicFilters);
  103. }
  104. public async Task DeleteSessionAsync(string clientId)
  105. {
  106. if (_connections.TryGetValue(clientId, out var connection))
  107. {
  108. await connection.StopAsync().ConfigureAwait(false);
  109. }
  110. if (_sessions.TryRemove(clientId, out var session))
  111. {
  112. }
  113. _logger.Verbose("Session for client '{0}' deleted.", clientId);
  114. }
  115. public void Dispose()
  116. {
  117. _messageQueue?.Dispose();
  118. }
  119. private async Task TryProcessQueuedApplicationMessagesAsync(CancellationToken cancellationToken)
  120. {
  121. while (!cancellationToken.IsCancellationRequested)
  122. {
  123. try
  124. {
  125. await TryProcessNextQueuedApplicationMessageAsync(cancellationToken).ConfigureAwait(false);
  126. }
  127. catch (OperationCanceledException)
  128. {
  129. }
  130. catch (Exception exception)
  131. {
  132. _logger.Error(exception, "Unhandled exception while processing queued application messages.");
  133. }
  134. }
  135. }
  136. private async Task TryProcessNextQueuedApplicationMessageAsync(CancellationToken cancellationToken)
  137. {
  138. try
  139. {
  140. if (cancellationToken.IsCancellationRequested)
  141. {
  142. return;
  143. }
  144. var dequeueResult = await _messageQueue.TryDequeueAsync(cancellationToken).ConfigureAwait(false);
  145. var queuedApplicationMessage = dequeueResult.Item;
  146. var sender = queuedApplicationMessage.Sender;
  147. var applicationMessage = queuedApplicationMessage.ApplicationMessage;
  148. var interceptorContext = await InterceptApplicationMessageAsync(sender, applicationMessage).ConfigureAwait(false);
  149. if (interceptorContext != null)
  150. {
  151. if (interceptorContext.CloseConnection)
  152. {
  153. await queuedApplicationMessage.Sender.StopAsync().ConfigureAwait(false);
  154. }
  155. if (interceptorContext.ApplicationMessage == null || !interceptorContext.AcceptPublish)
  156. {
  157. return;
  158. }
  159. applicationMessage = interceptorContext.ApplicationMessage;
  160. }
  161. await _eventDispatcher.HandleApplicationMessageReceivedAsync(sender?.ClientId, applicationMessage).ConfigureAwait(false);
  162. if (applicationMessage.Retain)
  163. {
  164. await _retainedMessagesManager.HandleMessageAsync(sender?.ClientId, applicationMessage).ConfigureAwait(false);
  165. }
  166. foreach (var clientSession in _sessions.Values)
  167. {
  168. clientSession.EnqueueApplicationMessage(
  169. queuedApplicationMessage.ApplicationMessage,
  170. sender?.ClientId,
  171. false);
  172. }
  173. }
  174. catch (OperationCanceledException)
  175. {
  176. }
  177. catch (Exception exception)
  178. {
  179. _logger.Error(exception, "Unhandled exception while processing next queued application message.");
  180. }
  181. }
  182. private async Task HandleConnectionAsync(IMqttChannelAdapter channelAdapter, CancellationToken cancellationToken)
  183. {
  184. var disconnectType = MqttClientDisconnectType.NotClean;
  185. var clientId = string.Empty;
  186. try
  187. {
  188. var firstPacket = await channelAdapter.ReceivePacketAsync(_options.DefaultCommunicationTimeout, cancellationToken).ConfigureAwait(false);
  189. if (!(firstPacket is MqttConnectPacket connectPacket))
  190. {
  191. _logger.Warning(null, "The first packet from client '{0}' was no 'CONNECT' packet [MQTT-3.1.0-1].", channelAdapter.Endpoint);
  192. return;
  193. }
  194. clientId = connectPacket.ClientId;
  195. var validatorContext = await ValidateConnectionAsync(connectPacket, channelAdapter).ConfigureAwait(false);
  196. if (validatorContext.ReturnCode != MqttConnectReturnCode.ConnectionAccepted)
  197. {
  198. // TODO: Move to channel adapter data converter.
  199. // Send failure response here without preparing a session. The result for a successful connect
  200. // will be sent from the session itself.
  201. await channelAdapter.SendPacketAsync(
  202. new MqttConnAckPacket
  203. {
  204. ReturnCode = validatorContext.ReturnCode,
  205. ReasonCode = MqttConnectReasonCode.NotAuthorized
  206. },
  207. _options.DefaultCommunicationTimeout,
  208. cancellationToken).ConfigureAwait(false);
  209. return;
  210. }
  211. var connection = await CreateConnectionAsync(channelAdapter, connectPacket).ConfigureAwait(false);
  212. await _eventDispatcher.HandleClientConnectedAsync(clientId).ConfigureAwait(false);
  213. disconnectType = await connection.RunAsync().ConfigureAwait(false);
  214. }
  215. catch (OperationCanceledException)
  216. {
  217. }
  218. catch (Exception exception)
  219. {
  220. _logger.Error(exception, exception.Message);
  221. }
  222. finally
  223. {
  224. _connections.TryRemove(clientId, out _);
  225. ////connection?.ReferenceCounter.Decrement();
  226. ////if (connection?.ReferenceCounter.HasReferences == true)
  227. ////{
  228. //// disconnectType = MqttClientDisconnectType.Takeover;
  229. ////}
  230. ////else
  231. {
  232. if (!_options.EnablePersistentSessions)
  233. {
  234. await DeleteSessionAsync(clientId).ConfigureAwait(false);
  235. }
  236. }
  237. await TryCleanupChannelAsync(channelAdapter).ConfigureAwait(false);
  238. await _eventDispatcher.HandleClientDisconnectedAsync(clientId, disconnectType).ConfigureAwait(false);
  239. }
  240. }
  241. private async Task<MqttConnectionValidatorContext> ValidateConnectionAsync(MqttConnectPacket connectPacket, IMqttChannelAdapter clientAdapter)
  242. {
  243. var context = new MqttConnectionValidatorContext(
  244. connectPacket.ClientId,
  245. connectPacket.Username,
  246. connectPacket.Password,
  247. connectPacket.WillMessage,
  248. clientAdapter.Endpoint,
  249. clientAdapter.IsSecureConnection);
  250. var connectionValidator = _options.ConnectionValidator;
  251. if (connectionValidator == null)
  252. {
  253. context.ReturnCode = MqttConnectReturnCode.ConnectionAccepted;
  254. return context;
  255. }
  256. await connectionValidator.ValidateConnectionAsync(context).ConfigureAwait(false);
  257. return context;
  258. }
  259. private async Task<MqttClientConnection> CreateConnectionAsync(IMqttChannelAdapter channelAdapter, MqttConnectPacket connectPacket)
  260. {
  261. await _createConnectionGate.WaitAsync(_cancellationToken).ConfigureAwait(false);
  262. try
  263. {
  264. var isSessionPresent = _sessions.TryGetValue(connectPacket.ClientId, out var session);
  265. var isConnectionPresent = _connections.TryGetValue(connectPacket.ClientId, out var existingConnection);
  266. if (isConnectionPresent)
  267. {
  268. await existingConnection.StopAsync().ConfigureAwait(false);
  269. }
  270. if (isSessionPresent)
  271. {
  272. if (connectPacket.CleanSession)
  273. {
  274. // TODO: Check if required.
  275. //session.Dispose();
  276. session = null;
  277. _logger.Verbose("Deleting existing session of client '{0}'.", connectPacket.ClientId);
  278. }
  279. else
  280. {
  281. _logger.Verbose("Reusing existing session of client '{0}'.", connectPacket.ClientId);
  282. }
  283. }
  284. if (session == null)
  285. {
  286. session = new MqttClientSession(connectPacket.ClientId, _eventDispatcher, _options);
  287. _logger.Verbose("Created a new session for client '{0}'.", connectPacket.ClientId);
  288. }
  289. var connection = new MqttClientConnection(connectPacket, channelAdapter, session, _options, this, _retainedMessagesManager, _logger);
  290. _connections[connection.ClientId] = connection;
  291. _sessions[session.ClientId] = session;
  292. return connection;
  293. }
  294. finally
  295. {
  296. _createConnectionGate.Release();
  297. }
  298. }
  299. private async Task<MqttApplicationMessageInterceptorContext> InterceptApplicationMessageAsync(MqttClientConnection sender, MqttApplicationMessage applicationMessage)
  300. {
  301. var interceptor = _options.ApplicationMessageInterceptor;
  302. if (interceptor == null)
  303. {
  304. return null;
  305. }
  306. var interceptorContext = new MqttApplicationMessageInterceptorContext(sender?.ClientId, applicationMessage);
  307. await interceptor.InterceptApplicationMessagePublishAsync(interceptorContext).ConfigureAwait(false);
  308. return interceptorContext;
  309. }
  310. private async Task TryCleanupChannelAsync(IMqttChannelAdapter channelAdapter)
  311. {
  312. try
  313. {
  314. await channelAdapter.DisconnectAsync(_options.DefaultCommunicationTimeout, CancellationToken.None).ConfigureAwait(false);
  315. }
  316. catch (Exception exception)
  317. {
  318. _logger.Error(exception, "Error while disconnecting client channel.");
  319. }
  320. finally
  321. {
  322. channelAdapter.Dispose();
  323. }
  324. }
  325. }
  326. }