You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 

164 lines
5.1 KiB

  1. using System;
  2. using System.Collections.Generic;
  3. using System.Linq;
  4. using System.Threading.Tasks;
  5. using Microsoft.VisualStudio.TestTools.UnitTesting;
  6. using MQTTnet.Client;
  7. using MQTTnet.Client.Options;
  8. using MQTTnet.Diagnostics;
  9. using MQTTnet.Internal;
  10. using MQTTnet.Server;
  namespace MQTTnet.Tests.Mockups
  {
      /// <summary>
      /// Test fixture that creates an MQTT server and MQTT clients, records error-level messages
      /// published by the server and client loggers, and collects exceptions reported through
      /// <see cref="TrackException"/>. Disposing the environment disposes the created clients,
      /// stops the server and then throws if any (non-ignored) errors or exceptions were recorded,
      /// so that the owning test fails during teardown.
      /// </summary>
      public class TestEnvironment : Disposable
      {
          private readonly MqttFactory _mqttFactory = new MqttFactory();

          // Unwrapped clients handed out (wrapped) by CreateClient; disposed in Dispose(bool).
          private readonly List<IMqttClient> _clients = new List<IMqttClient>();

          private readonly IMqttNetLogger _serverLogger = new MqttNetLogger("server");
          private readonly IMqttNetLogger _clientLogger = new MqttNetLogger("client");

          // Each of these lists is guarded by locking on the list instance itself.
          // NOTE(review): the locking implies log events / TrackException calls may arrive from
          // threads other than the test thread — confirm against MqttNetLogger.
          private readonly List<string> _serverErrors = new List<string>();
          private readonly List<string> _clientErrors = new List<string>();
          private readonly List<Exception> _exceptions = new List<Exception>();

          /// <summary>The server started by <see cref="StartServerAsync(MqttServerOptionsBuilder)"/>, or <c>null</c> if none was started.</summary>
          public IMqttServer Server { get; private set; }

          /// <summary>When <c>true</c>, error-level client log messages do not cause <see cref="ThrowIfLogErrors"/> to throw.</summary>
          public bool IgnoreClientLogErrors { get; set; }

          /// <summary>When <c>true</c>, error-level server log messages do not cause <see cref="ThrowIfLogErrors"/> to throw.</summary>
          public bool IgnoreServerLogErrors { get; set; }

          /// <summary>TCP port used both as the server's default endpoint port and as the port clients connect to. Defaults to 1888.</summary>
          public int ServerPort { get; set; } = 1888;

          /// <summary>Logger passed to the server created by this environment.</summary>
          public IMqttNetLogger ServerLogger => _serverLogger;

          /// <summary>Logger passed to every client created by this environment.</summary>
          public IMqttNetLogger ClientLogger => _clientLogger;

          /// <summary>The MSTest context handed to the client and server wrappers.</summary>
          public TestContext TestContext { get; }

          /// <summary>
          /// Creates the environment and subscribes to the server and client loggers so that every
          /// error-level message is recorded for <see cref="ThrowIfLogErrors"/>.
          /// </summary>
          /// <param name="testContext">The MSTest context of the running test.</param>
          public TestEnvironment(TestContext testContext)
          {
              TestContext = testContext;

              _serverLogger.LogMessagePublished += (s, e) =>
              {
                  if (e.TraceMessage.Level != MqttNetLogLevel.Error)
                  {
                      return;
                  }

                  lock (_serverErrors)
                  {
                      _serverErrors.Add(e.TraceMessage.ToString());
                  }
              };

              _clientLogger.LogMessagePublished += (s, e) =>
              {
                  if (e.TraceMessage.Level != MqttNetLogLevel.Error)
                  {
                      return;
                  }

                  lock (_clientErrors)
                  {
                      _clientErrors.Add(e.TraceMessage.ToString());
                  }
              };
          }

          /// <summary>
          /// Creates a (not yet connected) client that logs to <see cref="ClientLogger"/>.
          /// The underlying client is tracked and disposed together with this environment.
          /// </summary>
          /// <returns>The created client wrapped in a <c>TestClientWrapper</c>.</returns>
          public IMqttClient CreateClient()
          {
              var client = _mqttFactory.CreateMqttClient(_clientLogger);
              _clients.Add(client);
              return new TestClientWrapper(client, TestContext);
          }

          /// <summary>Starts the server with default options (see <see cref="StartServerAsync(MqttServerOptionsBuilder)"/>).</summary>
          public Task<IMqttServer> StartServerAsync()
          {
              return StartServerAsync(new MqttServerOptionsBuilder());
          }

          /// <summary>
          /// Creates a server that logs to <see cref="ServerLogger"/>, wraps it in a
          /// <c>TestServerWrapper</c> and starts it with <paramref name="options"/>. The builder's
          /// default endpoint port is set to <see cref="ServerPort"/>.
          /// </summary>
          /// <param name="options">Server options to build from; must not be <c>null</c>.</param>
          /// <returns>The started (wrapped) server, also available as <see cref="Server"/>.</returns>
          /// <exception cref="ArgumentNullException"><paramref name="options"/> is <c>null</c>.</exception>
          /// <exception cref="InvalidOperationException">A server was already started by this environment.</exception>
          public async Task<IMqttServer> StartServerAsync(MqttServerOptionsBuilder options)
          {
              if (options == null)
              {
                  throw new ArgumentNullException(nameof(options));
              }

              if (Server != null)
              {
                  throw new InvalidOperationException("Server already started.");
              }

              Server = new TestServerWrapper(_mqttFactory.CreateMqttServer(_serverLogger), TestContext, this);
              await Server.StartAsync(options.WithDefaultEndpointPort(ServerPort).Build());
              return Server;
          }

          /// <summary>Creates a client and connects it to <c>localhost</c> on <see cref="ServerPort"/> with default options.</summary>
          public Task<IMqttClient> ConnectClientAsync()
          {
              return ConnectClientAsync(new MqttClientOptionsBuilder());
          }

          /// <summary>
          /// Creates a client and connects it via TCP to <c>localhost</c> on <see cref="ServerPort"/>,
          /// using the remaining settings from <paramref name="options"/>.
          /// </summary>
          /// <exception cref="ArgumentNullException"><paramref name="options"/> is <c>null</c>.</exception>
          public async Task<IMqttClient> ConnectClientAsync(MqttClientOptionsBuilder options)
          {
              if (options == null)
              {
                  throw new ArgumentNullException(nameof(options));
              }

              var client = CreateClient();
              await client.ConnectAsync(options.WithTcpServer("localhost", ServerPort).Build());
              return client;
          }

          /// <summary>Creates a client and connects it using the fully built <paramref name="options"/> as-is.</summary>
          /// <exception cref="ArgumentNullException"><paramref name="options"/> is <c>null</c>.</exception>
          public async Task<IMqttClient> ConnectClientAsync(IMqttClientOptions options)
          {
              if (options == null)
              {
                  throw new ArgumentNullException(nameof(options));
              }

              var client = CreateClient();
              await client.ConnectAsync(options);
              return client;
          }

          /// <summary>
          /// Throws if error-level messages were logged by the server or by any client, unless the
          /// corresponding <see cref="IgnoreServerLogErrors"/> / <see cref="IgnoreClientLogErrors"/>
          /// flag is set. Server errors are checked before client errors.
          /// </summary>
          /// <exception cref="Exception">A non-ignored error was logged; the message lists every recorded error, one per line.</exception>
          public void ThrowIfLogErrors()
          {
              lock (_serverErrors)
              {
                  if (!IgnoreServerLogErrors && _serverErrors.Count > 0)
                  {
                      throw new Exception($"Server had {_serverErrors.Count} errors ({string.Join(Environment.NewLine, _serverErrors)}).");
                  }
              }

              lock (_clientErrors)
              {
                  if (!IgnoreClientLogErrors && _clientErrors.Count > 0)
                  {
                      throw new Exception($"Client(s) had {_clientErrors.Count} errors ({string.Join(Environment.NewLine, _clientErrors)}).");
                  }
              }
          }

          /// <summary>
          /// Disposes all created clients, stops the server and then reports recorded log errors
          /// and tracked exceptions by throwing. The base implementation is always invoked, even
          /// when one of those steps throws.
          /// </summary>
          protected override void Dispose(bool disposing)
          {
              try
              {
                  if (disposing)
                  {
                      foreach (var mqttClient in _clients)
                      {
                          mqttClient?.Dispose();
                      }

                      // Dispose is synchronous, so the async stop has to be waited on here.
                      Server?.StopAsync().GetAwaiter().GetResult();

                      // Throwing from Dispose is deliberate: it makes the owning test fail on teardown.
                      ThrowIfLogErrors();
                      ThrowIfTrackedExceptions();
                  }
              }
              finally
              {
                  base.Dispose(disposing);
              }
          }

          /// <summary>
          /// Records an exception that will cause <see cref="Dispose(bool)"/> to throw.
          /// </summary>
          /// <param name="exception">The exception to record.</param>
          public void TrackException(Exception exception)
          {
              lock (_exceptions)
              {
                  _exceptions.Add(exception);
              }
          }

          // Throws if TrackException recorded anything; reads the list under the same lock TrackException uses.
          private void ThrowIfTrackedExceptions()
          {
              lock (_exceptions)
              {
                  if (_exceptions.Count > 0)
                  {
                      throw new Exception($"{_exceptions.Count} exceptions tracked.{Environment.NewLine}" + string.Join(Environment.NewLine, _exceptions));
                  }
              }
          }
      }
  }