diff --git a/Tests/MQTTnet.Core.Tests/MqttServerTests.cs b/Tests/MQTTnet.Core.Tests/MqttServerTests.cs index b709835..fc25267 100644 --- a/Tests/MQTTnet.Core.Tests/MqttServerTests.cs +++ b/Tests/MQTTnet.Core.Tests/MqttServerTests.cs @@ -9,6 +9,7 @@ using MQTTnet.Core.Client; using MQTTnet.Core.Protocol; using MQTTnet.Core.Server; using Microsoft.Extensions.DependencyInjection; +using MQTTnet.Core.Internal; using MQTTnet.Core.Packets; namespace MQTTnet.Core.Tests @@ -212,7 +213,6 @@ namespace MQTTnet.Core.Tests .BuildServiceProvider(); var s = new MqttFactory(services).CreateMqttServer(); - var retainMessagemanager = services.GetRequiredService(); var receivedMessagesCount = 0; try @@ -223,19 +223,7 @@ namespace MQTTnet.Core.Tests await c1.PublishAsync(new MqttApplicationMessageBuilder().WithTopic("retained").WithPayload(new byte[3]).WithRetainFlag().Build()); await c1.DisconnectAsync(); - var subscribe = new MqttSubscribePacket() - { - TopicFilters = new List() - { - new TopicFilter("retained", MqttQualityOfServiceLevel.AtMostOnce) - } - }; - - //make shure the retainedMessageManagerreceived the package - while (!(await retainMessagemanager.GetSubscribedMessagesAsync(subscribe)).Any()) - { - await Task.Delay(TimeSpan.FromMilliseconds(10)); - } + await services.WaitForRetainedMessage("retained").TimeoutAfter(TimeSpan.FromSeconds(5)); var c2 = await serverAdapter.ConnectTestClient(s, "c2"); c2.ApplicationMessageReceived += (_, __) => receivedMessagesCount++; @@ -314,6 +302,8 @@ namespace MQTTnet.Core.Tests await s.StopAsync(); } + await services.WaitForRetainedMessage("retained").TimeoutAfter(TimeSpan.FromSeconds(5)); + s = new MqttFactory(services).CreateMqttServer(options => options.Storage = storage); var receivedMessagesCount = 0; @@ -436,4 +426,25 @@ namespace MQTTnet.Core.Tests Assert.AreEqual(expectedReceivedMessagesCount, receivedMessagesCount); } } + + public static class TestExtensions + { + public static async Task WaitForRetainedMessage(this IServiceProvider services, string topic) + { + var retainMessagemanager = services.GetRequiredService(); + + var subscribe = new MqttSubscribePacket() + { + TopicFilters = new List() + { + new TopicFilter(topic, MqttQualityOfServiceLevel.AtMostOnce) + } + }; + + while (!(await retainMessagemanager.GetSubscribedMessagesAsync(subscribe)).Any()) + { + await Task.Delay(TimeSpan.FromMilliseconds(10)); + } + } + } }