You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 

289 lines
11 KiB

  1. using System;
  2. using System.Collections.Concurrent;
  3. using System.Collections.Generic;
  4. using System.Threading;
  5. using System.Threading.Tasks;
  6. using MQTTnet.Adapter;
  7. using MQTTnet.Diagnostics;
  8. using MQTTnet.Exceptions;
  9. using MQTTnet.Internal;
  10. using MQTTnet.Packets;
  11. using MQTTnet.Protocol;
  namespace MQTTnet.Server
  {
    /// <summary>
    /// Owns the client sessions of an <see cref="MqttServer"/>: runs the CONNECT handshake for each
    /// incoming channel adapter, creates or reuses the matching <see cref="MqttClientSession"/>, and
    /// fans published application messages out to every known session.
    /// </summary>
    public class MqttClientSessionsManager : IDisposable
    {
      private readonly ConcurrentDictionary<string, MqttClientSession> _sessions = new ConcurrentDictionary<string, MqttClientSession>();

      // Serializes the lookup / replace / create sequence in PrepareClientSessionAsync.
      private readonly AsyncLock _sessionPreparationLock = new AsyncLock();
      private readonly MqttRetainedMessagesManager _retainedMessagesManager;
      private readonly IMqttServerOptions _options;
      private readonly IMqttNetChildLogger _logger;

      /// <summary>
      /// Creates a session manager bound to <paramref name="server"/>.
      /// </summary>
      /// <exception cref="ArgumentNullException">Thrown when any argument is null.</exception>
      public MqttClientSessionsManager(IMqttServerOptions options, MqttServer server, MqttRetainedMessagesManager retainedMessagesManager, IMqttNetChildLogger logger)
      {
        if (logger == null) throw new ArgumentNullException(nameof(logger));
        _logger = logger.CreateChildLogger(nameof(MqttClientSessionsManager));

        _options = options ?? throw new ArgumentNullException(nameof(options));
        Server = server ?? throw new ArgumentNullException(nameof(server));
        _retainedMessagesManager = retainedMessagesManager ?? throw new ArgumentNullException(nameof(retainedMessagesManager));
      }

      /// <summary>The server whose connect/disconnect/message callbacks this manager raises.</summary>
      public MqttServer Server { get; }

      /// <summary>
      /// Handles one client connection from its first packet until it ends.
      /// </summary>
      /// <remarks>
      /// The first packet must be CONNECT. A rejected connection gets a CONNACK carrying the
      /// validator's return code and is closed. An accepted one gets a session and a CONNACK, and
      /// then runs until the session ends.
      /// The adapter is always disconnected and disposed afterwards.
      /// Exceptions raised while handling the connection are logged rather than rethrown.
      /// </remarks>
      public async Task RunSessionAsync(IMqttChannelAdapter clientAdapter, CancellationToken cancellationToken)
      {
        var clientId = string.Empty;
        var wasCleanDisconnect = false;

        // The session prepared for THIS connection (null until PrepareClientSessionAsync succeeds).
        MqttClientSession clientSession = null;

        // Set once OnClientConnected is raised so OnClientDisconnected is only raised as its counterpart.
        var clientConnectedRaised = false;

        try
        {
          var firstPacket = await clientAdapter.ReceivePacketAsync(_options.DefaultCommunicationTimeout, cancellationToken).ConfigureAwait(false);
          if (firstPacket == null)
          {
            return;
          }

          if (!(firstPacket is MqttConnectPacket connectPacket))
          {
            throw new MqttProtocolViolationException("The first packet from a client must be a 'CONNECT' packet [MQTT-3.1.0-1].");
          }

          clientId = connectPacket.ClientId;

          // Switch to the required protocol version before sending any response.
          clientAdapter.PacketSerializer.ProtocolVersion = connectPacket.ProtocolVersion;

          var connectReturnCode = ValidateConnection(connectPacket);
          if (connectReturnCode != MqttConnectReturnCode.ConnectionAccepted)
          {
            await clientAdapter.SendPacketsAsync(_options.DefaultCommunicationTimeout, new[]
            {
              new MqttConnAckPacket
              {
                ConnectReturnCode = connectReturnCode
              }
            }, cancellationToken).ConfigureAwait(false);

            return;
          }

          var result = await PrepareClientSessionAsync(connectPacket).ConfigureAwait(false);
          clientSession = result.Session;

          await clientAdapter.SendPacketsAsync(_options.DefaultCommunicationTimeout, new[]
          {
            new MqttConnAckPacket
            {
              ConnectReturnCode = connectReturnCode,
              IsSessionPresent = result.IsExistingSession
            }
          }, cancellationToken).ConfigureAwait(false);

          // Mark before raising so that a throwing handler still receives the matching disconnect.
          clientConnectedRaised = true;
          Server.OnClientConnected(clientId);

          wasCleanDisconnect = await clientSession.RunAsync(connectPacket, clientAdapter).ConfigureAwait(false);
        }
        catch (OperationCanceledException)
        {
          // Cancellation is the expected way for the server to end a connection; nothing to report.
        }
        catch (Exception exception)
        {
          _logger.Error(exception, exception.Message);
        }
        finally
        {
          await CloseAdapterAsync(clientAdapter).ConfigureAwait(false);

          // Only drop the session this connection prepared. A rejected CONNECT must not remove the
          // session of an already connected client that uses the same client ID.
          if (clientSession != null && !_options.EnablePersistentSessions)
          {
            RemoveSessionIfCurrent(clientId, clientSession);
          }

          if (clientConnectedRaised)
          {
            Server.OnClientDisconnected(clientId, wasCleanDisconnect);
          }
        }
      }

      /// <summary>
      /// Stops every known session (as a non-clean disconnect) and clears the session table.
      /// </summary>
      public Task StopAsync()
      {
        foreach (var session in _sessions)
        {
          session.Value.Stop(MqttClientDisconnectType.NotClean);
        }

        _sessions.Clear();
        return Task.FromResult(0);
      }

      /// <summary>
      /// Returns a status snapshot for every session currently in the table.
      /// </summary>
      public Task<IList<IMqttClientSessionStatus>> GetClientStatusAsync()
      {
        var result = new List<IMqttClientSessionStatus>();

        foreach (var session in _sessions)
        {
          var status = new MqttClientSessionStatus(this, session.Value);
          session.Value.FillStatus(status);

          result.Add(status);
        }

        return Task.FromResult((IList<IMqttClientSessionStatus>)result);
      }

      /// <summary>
      /// Queues <paramref name="applicationMessage"/> for dispatch on the thread pool and returns immediately.
      /// </summary>
      /// <remarks>
      /// Failures are logged by the dispatch routine and are not observed by the caller.
      /// <paramref name="senderClientSession"/> may be null.
      /// </remarks>
      public void StartDispatchApplicationMessage(MqttClientSession senderClientSession, MqttApplicationMessage applicationMessage)
      {
        Task.Run(() => DispatchApplicationMessageAsync(senderClientSession, applicationMessage));
      }

      /// <summary>
      /// Adds <paramref name="topicFilters"/> to the subscriptions of the session for <paramref name="clientId"/>.
      /// </summary>
      /// <exception cref="ArgumentNullException">Thrown when an argument is null.</exception>
      /// <exception cref="InvalidOperationException">Thrown when no session exists for <paramref name="clientId"/>.</exception>
      public Task SubscribeAsync(string clientId, IList<TopicFilter> topicFilters)
      {
        if (clientId == null) throw new ArgumentNullException(nameof(clientId));
        if (topicFilters == null) throw new ArgumentNullException(nameof(topicFilters));

        return GetExistingSession(clientId).SubscribeAsync(topicFilters);
      }

      /// <summary>
      /// Removes <paramref name="topicFilters"/> from the subscriptions of the session for <paramref name="clientId"/>.
      /// </summary>
      /// <exception cref="ArgumentNullException">Thrown when an argument is null.</exception>
      /// <exception cref="InvalidOperationException">Thrown when no session exists for <paramref name="clientId"/>.</exception>
      public Task UnsubscribeAsync(string clientId, IList<string> topicFilters)
      {
        if (clientId == null) throw new ArgumentNullException(nameof(clientId));
        if (topicFilters == null) throw new ArgumentNullException(nameof(topicFilters));

        return GetExistingSession(clientId).UnsubscribeAsync(topicFilters);
      }

      /// <summary>
      /// Removes whatever session is registered for <paramref name="clientId"/> (no-op if none).
      /// </summary>
      public void DeleteSession(string clientId)
      {
        _sessions.TryRemove(clientId, out _);
        _logger.Verbose("Session for client '{0}' deleted.", clientId);
      }

      /// <summary>Releases the session preparation lock.</summary>
      public void Dispose()
      {
        _sessionPreparationLock?.Dispose();
      }

      // Looks up the session for clientId, throwing the same error SubscribeAsync/UnsubscribeAsync report.
      private MqttClientSession GetExistingSession(string clientId)
      {
        if (!_sessions.TryGetValue(clientId, out var session))
        {
          throw new InvalidOperationException($"Client session '{clientId}' is unknown.");
        }

        return session;
      }

      // Removes `session` only while it is still the value registered for `clientId`.
      // PrepareClientSessionAsync may have replaced it with a newer connection's session, which must survive.
      // ConcurrentDictionary's ICollection<KeyValuePair>.Remove compares key AND value atomically.
      private void RemoveSessionIfCurrent(string clientId, MqttClientSession session)
      {
        var entry = new KeyValuePair<string, MqttClientSession>(clientId, session);
        if (((ICollection<KeyValuePair<string, MqttClientSession>>)_sessions).Remove(entry))
        {
          _logger.Verbose("Session for client '{0}' deleted.", clientId);
        }
      }

      // Disconnects and disposes the adapter. Dispose runs even when DisconnectAsync fails, so the
      // adapter is not leaked; both failures are logged and swallowed.
      private async Task CloseAdapterAsync(IMqttChannelAdapter clientAdapter)
      {
        try
        {
          await clientAdapter.DisconnectAsync(_options.DefaultCommunicationTimeout, CancellationToken.None).ConfigureAwait(false);
        }
        catch (Exception exception)
        {
          _logger.Error(exception, exception.Message);
        }

        try
        {
          clientAdapter.Dispose();
        }
        catch (Exception exception)
        {
          _logger.Error(exception, exception.Message);
        }
      }

      // Runs the configured ConnectionValidator (if any) and returns its verdict; accepts when none is set.
      private MqttConnectReturnCode ValidateConnection(MqttConnectPacket connectPacket)
      {
        if (_options.ConnectionValidator == null)
        {
          return MqttConnectReturnCode.ConnectionAccepted;
        }

        var context = new MqttConnectionValidatorContext(
          connectPacket.ClientId,
          connectPacket.Username,
          connectPacket.Password,
          connectPacket.WillMessage);

        _options.ConnectionValidator(context);
        return context.ReturnCode;
      }

      // Under the preparation lock: reuse the existing session for the client ID unless CleanSession is set,
      // in which case the old session is removed, stopped and disposed; otherwise create and register a new one.
      private async Task<GetOrCreateClientSessionResult> PrepareClientSessionAsync(MqttConnectPacket connectPacket)
      {
        using (await _sessionPreparationLock.LockAsync(CancellationToken.None).ConfigureAwait(false))
        {
          var isSessionPresent = _sessions.TryGetValue(connectPacket.ClientId, out var clientSession);
          if (isSessionPresent)
          {
            if (connectPacket.CleanSession)
            {
              _sessions.TryRemove(connectPacket.ClientId, out _);
              clientSession.Stop(MqttClientDisconnectType.Clean);
              clientSession.Dispose();
              clientSession = null;

              _logger.Verbose("Stopped existing session of client '{0}'.", connectPacket.ClientId);
            }
            else
            {
              _logger.Verbose("Reusing existing session of client '{0}'.", connectPacket.ClientId);
            }
          }

          var isExistingSession = true;
          if (clientSession == null)
          {
            isExistingSession = false;

            clientSession = new MqttClientSession(connectPacket.ClientId, _options, this, _retainedMessagesManager, _logger);
            _sessions[connectPacket.ClientId] = clientSession;

            _logger.Verbose("Created a new session for client '{0}'.", connectPacket.ClientId);
          }

          return new GetOrCreateClientSessionResult { IsExistingSession = isExistingSession, Session = clientSession };
        }
      }

      // Runs the interceptor and then, if the publish is accepted, raises the received callback, updates
      // retained messages and enqueues the message into every session. All failures are logged.
      private async Task DispatchApplicationMessageAsync(MqttClientSession senderClientSession, MqttApplicationMessage applicationMessage)
      {
        try
        {
          var interceptorContext = InterceptApplicationMessage(senderClientSession, applicationMessage);
          if (interceptorContext.CloseConnection)
          {
            // The sender may be null (this method already uses senderClientSession?.ClientId below).
            senderClientSession?.Stop(MqttClientDisconnectType.NotClean);
          }

          // Use the interceptor's message so that replacements/modifications it made are honored.
          var interceptedMessage = interceptorContext.ApplicationMessage;
          if (interceptedMessage == null || !interceptorContext.AcceptPublish)
          {
            return;
          }

          var senderClientId = senderClientSession?.ClientId;
          Server.OnApplicationMessageReceived(senderClientId, interceptedMessage);

          if (interceptedMessage.Retain)
          {
            await _retainedMessagesManager.HandleMessageAsync(senderClientId, interceptedMessage).ConfigureAwait(false);
          }

          foreach (var clientSession in _sessions.Values)
          {
            clientSession.EnqueueApplicationMessage(senderClientSession, interceptedMessage);
          }
        }
        catch (Exception exception)
        {
          _logger.Error(exception, "Error while processing application message");
        }
      }

      // Builds the interceptor context and, when an ApplicationMessageInterceptor is configured, lets it inspect/modify it.
      private MqttApplicationMessageInterceptorContext InterceptApplicationMessage(MqttClientSession senderClientSession, MqttApplicationMessage applicationMessage)
      {
        var interceptorContext = new MqttApplicationMessageInterceptorContext(
          senderClientSession?.ClientId,
          applicationMessage);

        var interceptor = _options.ApplicationMessageInterceptor;
        if (interceptor == null)
        {
          return interceptorContext;
        }

        interceptor(interceptorContext);
        return interceptorContext;
      }
    }
  }