You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

TestEnvironment.cs 8.2 KiB

4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265
  1. using Microsoft.VisualStudio.TestTools.UnitTesting;
  2. using MQTTnet.Client;
  3. using MQTTnet.Client.Options;
  4. using MQTTnet.Diagnostics;
  5. using MQTTnet.Server;
  6. using System;
  7. using System.Collections.Generic;
  8. using System.Diagnostics;
  9. using System.Linq;
  10. using System.Threading;
  11. using System.Threading.Tasks;
  12. using MQTTnet.Extensions.Rpc;
  13. using MQTTnet.Extensions.Rpc.Options;
  14. using MQTTnet.LowLevelClient;
  namespace MQTTnet.Tests.Mockups
  {
      /// <summary>
      /// Per-test fixture that creates an MQTT server and MQTT clients through a shared
      /// <see cref="MqttFactory"/>, records error-level log messages from both sides, and
      /// tears everything down (failing the test on recorded errors) in <see cref="Dispose"/>.
      /// </summary>
      /// <remarks>
      /// Every private list below is synchronized by locking the list instance itself.
      /// </remarks>
      public sealed class TestEnvironment : IDisposable
      {
          readonly MqttFactory _mqttFactory = new MqttFactory();
          readonly List<ILowLevelMqttClient> _lowLevelClients = new List<ILowLevelMqttClient>();
          readonly List<IMqttClient> _clients = new List<IMqttClient>();
          readonly List<string> _serverErrors = new List<string>();
          readonly List<string> _clientErrors = new List<string>();
          readonly List<Exception> _exceptions = new List<Exception>();

          /// <summary>
          /// The server created by <see cref="StartServerAsync(MqttServerOptionsBuilder)"/>,
          /// or <c>null</c> if no server has been created yet.
          /// </summary>
          public IMqttServer Server { get; private set; }

          /// <summary>When <c>true</c>, <see cref="ThrowIfLogErrors"/> ignores client error logs.</summary>
          public bool IgnoreClientLogErrors { get; set; }

          /// <summary>When <c>true</c>, <see cref="ThrowIfLogErrors"/> ignores server error logs.</summary>
          public bool IgnoreServerLogErrors { get; set; }

          /// <summary>
          /// Port used as the server's default endpoint port and by the client connect helpers.
          /// Defaults to 1888.
          /// </summary>
          public int ServerPort { get; set; } = 1888;

          /// <summary>Logger handed to the server; its error-level messages are recorded.</summary>
          public MqttNetLogger ServerLogger { get; } = new MqttNetLogger("server");

          /// <summary>Logger handed to every client; its error-level messages are recorded.</summary>
          public MqttNetLogger ClientLogger { get; } = new MqttNetLogger("client");

          /// <summary>The MSTest context passed to the client/server wrappers (may be <c>null</c>).</summary>
          public TestContext TestContext { get; }

          /// <summary>Creates an environment without a <see cref="TestContext"/>.</summary>
          public TestEnvironment() : this(null)
          {
          }

          /// <summary>Creates an environment bound to <paramref name="testContext"/>.</summary>
          /// <param name="testContext">The MSTest context; may be <c>null</c>.</param>
          public TestEnvironment(TestContext testContext)
          {
              TestContext = testContext;

              ServerLogger.LogMessagePublished += (s, e) => RecordLogMessage(e.LogMessage.Level, e.LogMessage, _serverErrors);
              ClientLogger.LogMessagePublished += (s, e) => RecordLogMessage(e.LogMessage.Level, e.LogMessage, _clientErrors);
          }

          /// <summary>
          /// Echoes a log message to the debug output when a debugger is attached. Appends its text to
          /// <paramref name="errors"/> when it is an error.
          /// </summary>
          /// <param name="level">Level of the published log message.</param>
          /// <param name="logMessage">The published message; only its <c>ToString()</c> is used.</param>
          /// <param name="errors">Error list to append to (locked while appending).</param>
          static void RecordLogMessage(MqttNetLogLevel level, object logMessage, List<string> errors)
          {
              if (Debugger.IsAttached)
              {
                  Debug.WriteLine(logMessage.ToString());
              }

              if (level == MqttNetLogLevel.Error)
              {
                  lock (errors)
                  {
                      errors.Add(logMessage.ToString());
                  }
              }
          }

          /// <summary>
          /// Creates a client that uses <see cref="ClientLogger"/> and registers it for disposal.
          /// </summary>
          /// <returns>A <c>TestClientWrapper</c> around the newly created client.</returns>
          public IMqttClient CreateClient()
          {
              lock (_clients)
              {
                  var client = _mqttFactory.CreateMqttClient(ClientLogger);
                  _clients.Add(client);
                  return new TestClientWrapper(client, TestContext);
              }
          }

          /// <summary>Creates a client and connects it to <c>localhost:</c><see cref="ServerPort"/>.</summary>
          public Task<IMqttClient> ConnectClientAsync()
          {
              return ConnectClientAsync(new MqttClientOptionsBuilder());
          }

          /// <summary>
          /// Creates and connects a client. The options builder is first set to
          /// <c>localhost:</c><see cref="ServerPort"/>. <paramref name="optionsBuilder"/> is invoked
          /// afterwards, so it can adjust or replace those settings.
          /// </summary>
          /// <exception cref="ArgumentNullException"><paramref name="optionsBuilder"/> is <c>null</c>.</exception>
          public async Task<IMqttClient> ConnectClientAsync(Action<MqttClientOptionsBuilder> optionsBuilder)
          {
              if (optionsBuilder == null) throw new ArgumentNullException(nameof(optionsBuilder));

              var options = new MqttClientOptionsBuilder();
              options = options.WithTcpServer("localhost", ServerPort);
              optionsBuilder.Invoke(options);

              var client = CreateClient();
              await client.ConnectAsync(options.Build()).ConfigureAwait(false);
              return client;
          }

          /// <summary>
          /// Creates and connects a client. Before building, the given builder is configured with
          /// <c>WithTcpServer("localhost", ServerPort)</c>.
          /// </summary>
          /// <exception cref="ArgumentNullException"><paramref name="options"/> is <c>null</c>.</exception>
          public async Task<IMqttClient> ConnectClientAsync(MqttClientOptionsBuilder options)
          {
              if (options == null) throw new ArgumentNullException(nameof(options));

              options = options.WithTcpServer("localhost", ServerPort);
              var client = CreateClient();
              await client.ConnectAsync(options.Build()).ConfigureAwait(false);
              return client;
          }

          /// <summary>Creates a client and connects it using <paramref name="options"/> unchanged.</summary>
          /// <exception cref="ArgumentNullException"><paramref name="options"/> is <c>null</c>.</exception>
          public async Task<IMqttClient> ConnectClientAsync(IMqttClientOptions options)
          {
              if (options == null) throw new ArgumentNullException(nameof(options));

              var client = CreateClient();
              await client.ConnectAsync(options).ConfigureAwait(false);
              return client;
          }

          /// <summary>
          /// Creates a low-level client that uses <see cref="ClientLogger"/> and registers it for disposal.
          /// </summary>
          public ILowLevelMqttClient CreateLowLevelClient()
          {
              // Lock the list actually being mutated (previously this locked _clients).
              lock (_lowLevelClients)
              {
                  var client = _mqttFactory.CreateLowLevelMqttClient(ClientLogger);
                  _lowLevelClients.Add(client);
                  return client;
              }
          }

          /// <summary>Creates and connects a low-level client with default options.</summary>
          public Task<ILowLevelMqttClient> ConnectLowLevelClientAsync()
          {
              return ConnectLowLevelClientAsync(o => { });
          }

          /// <summary>
          /// Creates and connects a low-level client. The options builder is first set to
          /// <c>127.0.0.1:</c><see cref="ServerPort"/>, then <paramref name="optionsBuilder"/> is invoked.
          /// </summary>
          /// <exception cref="ArgumentNullException"><paramref name="optionsBuilder"/> is <c>null</c>.</exception>
          public async Task<ILowLevelMqttClient> ConnectLowLevelClientAsync(Action<MqttClientOptionsBuilder> optionsBuilder)
          {
              if (optionsBuilder == null) throw new ArgumentNullException(nameof(optionsBuilder));

              var options = new MqttClientOptionsBuilder();
              // NOTE(review): uses 127.0.0.1 while the other connect helpers use "localhost" — confirm intentional.
              options = options.WithTcpServer("127.0.0.1", ServerPort);
              optionsBuilder.Invoke(options);

              var client = CreateLowLevelClient();
              await client.ConnectAsync(options.Build(), CancellationToken.None).ConfigureAwait(false);
              return client;
          }

          /// <summary>Connects a new client with default options and wraps it in an RPC client.</summary>
          /// <param name="options">RPC client options, passed through to <c>MqttRpcClient</c>.</param>
          public async Task<IMqttRpcClient> ConnectRpcClientAsync(IMqttRpcClientOptions options)
          {
              var client = await ConnectClientAsync().ConfigureAwait(false);
              return new MqttRpcClient(client, options);
          }

          /// <summary>Creates and starts the server with default options.</summary>
          public Task<IMqttServer> StartServerAsync()
          {
              return StartServerAsync(new MqttServerOptionsBuilder());
          }

          /// <summary>
          /// Creates the server (wrapped in a <c>TestServerWrapper</c>) and starts it. Before starting,
          /// the default endpoint port is set to <see cref="ServerPort"/> and the per-client pending
          /// message limit to <see cref="int.MaxValue"/>.
          /// </summary>
          /// <exception cref="ArgumentNullException"><paramref name="options"/> is <c>null</c>.</exception>
          /// <exception cref="InvalidOperationException">A server was already created by this environment.</exception>
          public async Task<IMqttServer> StartServerAsync(MqttServerOptionsBuilder options)
          {
              if (options == null) throw new ArgumentNullException(nameof(options));

              if (Server != null)
              {
                  throw new InvalidOperationException("Server already started.");
              }

              Server = new TestServerWrapper(_mqttFactory.CreateMqttServer(ServerLogger), TestContext, this);

              options.WithDefaultEndpointPort(ServerPort);
              options.WithMaxPendingMessagesPerClient(int.MaxValue);

              await Server.StartAsync(options.Build()).ConfigureAwait(false);
              return Server;
          }

          /// <summary>
          /// Throws if error-level messages were logged by the server or clients, unless the
          /// corresponding Ignore*LogErrors flag is set. Server errors are checked first.
          /// </summary>
          /// <exception cref="Exception">Non-ignored error log messages were recorded.</exception>
          public void ThrowIfLogErrors()
          {
              lock (_serverErrors)
              {
                  if (!IgnoreServerLogErrors && _serverErrors.Count > 0)
                  {
                      // Fixed: a stray '$' used to precede the interpolation hole and appeared literally in the message.
                      throw new Exception($"Server had {_serverErrors.Count} errors ({string.Join(Environment.NewLine, _serverErrors)}).");
                  }
              }

              lock (_clientErrors)
              {
                  if (!IgnoreClientLogErrors && _clientErrors.Count > 0)
                  {
                      throw new Exception($"Client(s) had {_clientErrors.Count} errors ({string.Join(Environment.NewLine, _clientErrors)}).");
                  }
              }
          }

          /// <summary>Records an exception; <see cref="Dispose"/> fails if any were recorded.</summary>
          public void TrackException(Exception exception)
          {
              lock (_exceptions)
              {
                  _exceptions.Add(exception);
              }
          }

          /// <summary>
          /// Disconnects and disposes every created client (disconnect failures are ignored), disposes
          /// the low-level clients, and stops and disposes the server (stop failures are ignored). It then
          /// throws if error logs (see <see cref="ThrowIfLogErrors"/>) or tracked exceptions were recorded.
          /// </summary>
          /// <exception cref="Exception">Errors were logged or exceptions were tracked during the test.</exception>
          public void Dispose()
          {
              // Snapshot under the same locks used by the Create* methods.
              List<IMqttClient> clients;
              lock (_clients)
              {
                  clients = _clients.ToList();
              }

              foreach (var mqttClient in clients)
              {
                  try
                  {
                      mqttClient.DisconnectAsync().GetAwaiter().GetResult();
                  }
                  catch
                  {
                      // This can happen when the test already disconnected the client.
                  }
                  finally
                  {
                      mqttClient?.Dispose();
                  }
              }

              List<ILowLevelMqttClient> lowLevelClients;
              lock (_lowLevelClients)
              {
                  lowLevelClients = _lowLevelClients.ToList();
              }

              foreach (var lowLevelMqttClient in lowLevelClients)
              {
                  lowLevelMqttClient.Dispose();
              }

              try
              {
                  Server?.StopAsync().GetAwaiter().GetResult();
              }
              catch
              {
                  // This can happen when the test already stopped the server.
              }
              finally
              {
                  Server?.Dispose();
              }

              ThrowIfLogErrors();

              lock (_exceptions)
              {
                  if (_exceptions.Count > 0)
                  {
                      throw new Exception($"{_exceptions.Count} exceptions tracked.\r\n" + string.Join(Environment.NewLine, _exceptions));
                  }
              }
          }
      }
  }