using Microsoft.VisualStudio.TestTools.UnitTesting;
using MQTTnet.Client;
using MQTTnet.Client.Options;
using MQTTnet.Diagnostics;
using MQTTnet.Server;
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using MQTTnet.Extensions.Rpc;
using MQTTnet.Extensions.Rpc.Options;
using MQTTnet.LowLevelClient;

namespace MQTTnet.Tests.Mockups
{
    /// <summary>
    /// Test fixture that owns at most one MQTT server and any number of MQTT clients,
    /// records every log message published at <see cref="MqttNetLogLevel.Error"/> level,
    /// and tears everything down again in <see cref="Dispose"/>.
    /// </summary>
    public sealed class TestEnvironment : IDisposable
    {
        readonly MqttFactory _mqttFactory = new MqttFactory();

        // Each list doubles as its own lock object; every access must happen inside lock (<list>).
        readonly List<IMqttClient> _clients = new List<IMqttClient>();
        readonly List<string> _serverErrors = new List<string>();
        readonly List<string> _clientErrors = new List<string>();
        readonly List<Exception> _exceptions = new List<Exception>();

        /// <summary>The server created by <c>StartServerAsync</c>; <c>null</c> until a server has been started.</summary>
        public IMqttServer Server { get; private set; }

        /// <summary>When <c>true</c>, error-level client log messages do not make <see cref="ThrowIfLogErrors"/> throw.</summary>
        public bool IgnoreClientLogErrors { get; set; }

        /// <summary>When <c>true</c>, error-level server log messages do not make <see cref="ThrowIfLogErrors"/> throw.</summary>
        public bool IgnoreServerLogErrors { get; set; }

        /// <summary>TCP port used for the server endpoint and for the clients connecting to it. Defaults to 1888.</summary>
        public int ServerPort { get; set; } = 1888;

        /// <summary>Logger handed to the server; its error messages are collected by this environment.</summary>
        public MqttNetLogger ServerLogger { get; } = new MqttNetLogger("server");

        /// <summary>Logger handed to clients created via <see cref="CreateClient"/>; its error messages are collected by this environment.</summary>
        public MqttNetLogger ClientLogger { get; } = new MqttNetLogger("client");

        /// <summary>The MSTest context supplied at construction; <c>null</c> when the parameterless constructor was used.</summary>
        public TestContext TestContext { get; }

        /// <summary>Creates an environment without a test context.</summary>
        public TestEnvironment() : this(null)
        {
        }

        /// <summary>
        /// Creates an environment and subscribes to both loggers so that error-level
        /// messages are recorded for <see cref="ThrowIfLogErrors"/>.
        /// </summary>
        /// <param name="testContext">Context forwarded to the client and server wrappers; may be <c>null</c>.</param>
        public TestEnvironment(TestContext testContext)
        {
            TestContext = testContext;

            ServerLogger.LogMessagePublished += (s, e) =>
            {
                if (Debugger.IsAttached)
                {
                    Debug.WriteLine(e.LogMessage.ToString());
                }

                if (e.LogMessage.Level == MqttNetLogLevel.Error)
                {
                    lock (_serverErrors)
                    {
                        _serverErrors.Add(e.LogMessage.ToString());
                    }
                }
            };

            ClientLogger.LogMessagePublished += (s, e) =>
            {
                if (Debugger.IsAttached)
                {
                    Debug.WriteLine(e.LogMessage.ToString());
                }

                if (e.LogMessage.Level == MqttNetLogLevel.Error)
                {
                    lock (_clientErrors)
                    {
                        _clientErrors.Add(e.LogMessage.ToString());
                    }
                }
            };
        }

        /// <summary>Connects a new client with default options and wraps it in an RPC client.</summary>
        /// <param name="options">Options passed to the <see cref="MqttRpcClient"/> constructor.</param>
        /// <returns>The RPC client built on top of the freshly connected client.</returns>
        public async Task<IMqttRpcClient> ConnectRpcClientAsync(IMqttRpcClientOptions options)
        {
            return new MqttRpcClient(await ConnectClientAsync(), options);
        }

        /// <summary>
        /// Creates an unconnected client that logs to <see cref="ClientLogger"/> and registers it
        /// for cleanup in <see cref="Dispose"/>.
        /// </summary>
        /// <returns>A <see cref="TestClientWrapper"/> around the newly created client.</returns>
        public IMqttClient CreateClient()
        {
            lock (_clients)
            {
                var client = _mqttFactory.CreateMqttClient(ClientLogger);

                // The inner client (not the wrapper) is what gets tracked and later disposed.
                _clients.Add(client);

                return new TestClientWrapper(client, TestContext);
            }
        }

        /// <summary>Starts the server with a default <see cref="MqttServerOptionsBuilder"/>.</summary>
        /// <returns>The started server.</returns>
        public Task<IMqttServer> StartServerAsync()
        {
            return StartServerAsync(new MqttServerOptionsBuilder());
        }

        /// <summary>
        /// Creates and starts the server. The passed builder is modified: its default endpoint port
        /// is set to <see cref="ServerPort"/> and the pending-message limit per client to <see cref="int.MaxValue"/>.
        /// </summary>
        /// <param name="options">Builder for the server options.</param>
        /// <returns>The started server (also available via <see cref="Server"/>).</returns>
        /// <exception cref="ArgumentNullException"><paramref name="options"/> is <c>null</c>.</exception>
        /// <exception cref="InvalidOperationException">A server has already been created by this environment.</exception>
        public async Task<IMqttServer> StartServerAsync(MqttServerOptionsBuilder options)
        {
            if (options == null) throw new ArgumentNullException(nameof(options));

            if (Server != null)
            {
                throw new InvalidOperationException("Server already started.");
            }

            Server = new TestServerWrapper(_mqttFactory.CreateMqttServer(ServerLogger), TestContext, this);

            options.WithDefaultEndpointPort(ServerPort);
            options.WithMaxPendingMessagesPerClient(int.MaxValue);

            await Server.StartAsync(options.Build()).ConfigureAwait(false);

            return Server;
        }

        /// <summary>Creates a client and connects it to "localhost" on <see cref="ServerPort"/> with default options.</summary>
        /// <returns>The connected client.</returns>
        public Task<IMqttClient> ConnectClientAsync()
        {
            return ConnectClientAsync(new MqttClientOptionsBuilder());
        }

        /// <summary>Connects a low-level client to "127.0.0.1" on <see cref="ServerPort"/> with no extra options.</summary>
        /// <returns>The connected low-level client.</returns>
        public Task<ILowLevelMqttClient> ConnectLowLevelClientAsync()
        {
            return ConnectLowLevelClientAsync(o => { });
        }

        /// <summary>
        /// Connects a low-level client. The builder is pre-configured for "127.0.0.1" on
        /// <see cref="ServerPort"/> before <paramref name="optionsBuilder"/> runs, so the callback may override it.
        /// </summary>
        /// <param name="optionsBuilder">Callback that customizes the client options.</param>
        /// <returns>The connected low-level client.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="optionsBuilder"/> is <c>null</c>.</exception>
        /// <remarks>
        /// The client is created by a separate <see cref="MqttFactory"/> without <see cref="ClientLogger"/>
        /// and is not added to the list cleaned up by <see cref="Dispose"/>; the caller owns it.
        /// </remarks>
        public async Task<ILowLevelMqttClient> ConnectLowLevelClientAsync(Action<MqttClientOptionsBuilder> optionsBuilder)
        {
            if (optionsBuilder == null) throw new ArgumentNullException(nameof(optionsBuilder));

            var options = new MqttClientOptionsBuilder();
            options = options.WithTcpServer("127.0.0.1", ServerPort);
            optionsBuilder.Invoke(options);

            var client = new MqttFactory().CreateLowLevelMqttClient();
            await client.ConnectAsync(options.Build(), CancellationToken.None).ConfigureAwait(false);

            return client;
        }

        /// <summary>
        /// Creates a client and connects it. The builder is pre-configured for "localhost" on
        /// <see cref="ServerPort"/> before <paramref name="optionsBuilder"/> runs, so the callback may override it.
        /// </summary>
        /// <param name="optionsBuilder">Callback that customizes the client options.</param>
        /// <returns>The connected client.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="optionsBuilder"/> is <c>null</c>.</exception>
        public async Task<IMqttClient> ConnectClientAsync(Action<MqttClientOptionsBuilder> optionsBuilder)
        {
            if (optionsBuilder == null) throw new ArgumentNullException(nameof(optionsBuilder));

            var options = new MqttClientOptionsBuilder();
            options = options.WithTcpServer("localhost", ServerPort);
            optionsBuilder.Invoke(options);

            var client = CreateClient();
            await client.ConnectAsync(options.Build()).ConfigureAwait(false);

            return client;
        }

        /// <summary>
        /// Creates a client and connects it using the given builder. The TCP server setting is applied
        /// after the caller's configuration, i.e. it is always "localhost" on <see cref="ServerPort"/>.
        /// </summary>
        /// <param name="options">Builder holding the remaining client options.</param>
        /// <returns>The connected client.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="options"/> is <c>null</c>.</exception>
        public async Task<IMqttClient> ConnectClientAsync(MqttClientOptionsBuilder options)
        {
            if (options == null) throw new ArgumentNullException(nameof(options));

            options = options.WithTcpServer("localhost", ServerPort);

            var client = CreateClient();
            await client.ConnectAsync(options.Build()).ConfigureAwait(false);

            return client;
        }

        /// <summary>Creates a client and connects it with exactly the given, already built options.</summary>
        /// <param name="options">Final client options; not modified by this method.</param>
        /// <returns>The connected client.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="options"/> is <c>null</c>.</exception>
        public async Task<IMqttClient> ConnectClientAsync(IMqttClientOptions options)
        {
            if (options == null) throw new ArgumentNullException(nameof(options));

            var client = CreateClient();
            await client.ConnectAsync(options).ConfigureAwait(false);

            return client;
        }

        /// <summary>
        /// Throws when error-level log messages were recorded and the corresponding
        /// <c>Ignore…LogErrors</c> flag is not set. Server errors are checked before client errors.
        /// </summary>
        /// <exception cref="Exception">At least one non-ignored error was logged; the message lists all of them.</exception>
        public void ThrowIfLogErrors()
        {
            lock (_serverErrors)
            {
                if (!IgnoreServerLogErrors && _serverErrors.Count > 0)
                {
                    throw new Exception($"Server had {_serverErrors.Count} errors ({string.Join(Environment.NewLine, _serverErrors)}).");
                }
            }

            lock (_clientErrors)
            {
                if (!IgnoreClientLogErrors && _clientErrors.Count > 0)
                {
                    throw new Exception($"Client(s) had {_clientErrors.Count} errors ({string.Join(Environment.NewLine, _clientErrors)}).");
                }
            }
        }

        /// <summary>
        /// Disconnects and disposes all clients created via <see cref="CreateClient"/>, stops and disposes
        /// the server, and finally fails with an exception when errors were logged or exceptions were tracked.
        /// </summary>
        /// <exception cref="Exception">Errors were logged (see <see cref="ThrowIfLogErrors"/>) or exceptions were tracked.</exception>
        public void Dispose()
        {
            // Snapshot under the lock used by CreateClient so a concurrent registration cannot
            // invalidate the enumeration; the blocking disconnects then run outside the lock.
            IMqttClient[] clients;
            lock (_clients)
            {
                clients = _clients.ToArray();
            }

            foreach (var mqttClient in clients)
            {
                try
                {
                    mqttClient.DisconnectAsync().GetAwaiter().GetResult();
                }
                catch
                {
                    // This can happen when the test already disconnected the client.
                }
                finally
                {
                    mqttClient?.Dispose();
                }
            }

            try
            {
                Server?.StopAsync().GetAwaiter().GetResult();
            }
            catch
            {
                // This can happen when the test already stopped the server.
            }
            finally
            {
                Server?.Dispose();
            }

            // Throwing from Dispose is intentional here: it is how problems that were only
            // logged or tracked during the test are surfaced once the environment is torn down.
            ThrowIfLogErrors();

            // Snapshot under the lock used by TrackException.
            Exception[] exceptions;
            lock (_exceptions)
            {
                exceptions = _exceptions.ToArray();
            }

            if (exceptions.Length > 0)
            {
                throw new Exception($"{exceptions.Length} exceptions tracked.\r\n" + string.Join(Environment.NewLine, exceptions));
            }
        }

        /// <summary>Records an exception so that <see cref="Dispose"/> reports it.</summary>
        /// <param name="exception">The exception to remember.</param>
        public void TrackException(Exception exception)
        {
            lock (_exceptions)
            {
                _exceptions.Add(exception);
            }
        }
    }
}