diff --git a/MQTTnet.Core/Server/IMqttClientRetainedMessageManager.cs b/MQTTnet.Core/Server/IMqttClientRetainedMessageManager.cs index cd3426e..f400899 100644 --- a/MQTTnet.Core/Server/IMqttClientRetainedMessageManager.cs +++ b/MQTTnet.Core/Server/IMqttClientRetainedMessageManager.cs @@ -10,6 +10,6 @@ namespace MQTTnet.Core.Server Task HandleMessageAsync(string clientId, MqttApplicationMessage applicationMessage); - List GetMessages(MqttSubscribePacket subscribePacket); + Task> GetSubscribedMessagesAsync(MqttSubscribePacket subscribePacket); } } diff --git a/MQTTnet.Core/Server/MqttClientRetainedMessagesManager.cs b/MQTTnet.Core/Server/MqttClientRetainedMessagesManager.cs index d299daa..dbdb1d7 100644 --- a/MQTTnet.Core/Server/MqttClientRetainedMessagesManager.cs +++ b/MQTTnet.Core/Server/MqttClientRetainedMessagesManager.cs @@ -1,6 +1,7 @@ using System; using System.Collections.Generic; using System.Linq; +using System.Threading; using System.Threading.Tasks; using MQTTnet.Core.Packets; using Microsoft.Extensions.Logging; @@ -11,6 +12,7 @@ namespace MQTTnet.Core.Server public sealed class MqttClientRetainedMessagesManager : IMqttClientRetainedMessageManager { private readonly Dictionary _retainedMessages = new Dictionary(); + private readonly SemaphoreSlim _gate = new SemaphoreSlim(1, 1); private readonly ILogger _logger; private readonly MqttServerOptions _options; @@ -27,62 +29,87 @@ namespace MQTTnet.Core.Server return; } + await _gate.WaitAsync(); try { var retainedMessages = await _options.Storage.LoadRetainedMessagesAsync(); - lock (_retainedMessages) + + _retainedMessages.Clear(); + foreach (var retainedMessage in retainedMessages) { - _retainedMessages.Clear(); - foreach (var retainedMessage in retainedMessages) - { - _retainedMessages[retainedMessage.Topic] = retainedMessage; - } + _retainedMessages[retainedMessage.Topic] = retainedMessage; } } catch (Exception exception) { _logger.LogError(new EventId(), exception, "Unhandled exception while loading retained messages."); } + finally + { + _gate.Release(); + } } public async Task HandleMessageAsync(string clientId, MqttApplicationMessage applicationMessage) { if (applicationMessage == null) throw new ArgumentNullException(nameof(applicationMessage)); - List allRetainedMessages; - lock (_retainedMessages) + await _gate.WaitAsync(); + try { + var saveIsRequired = false; + if (applicationMessage.Payload?.Any() == false) { - _retainedMessages.Remove(applicationMessage.Topic); + saveIsRequired = _retainedMessages.Remove(applicationMessage.Topic); _logger.LogInformation("Client '{0}' cleared retained message for topic '{1}'.", clientId, applicationMessage.Topic); } else { - _retainedMessages[applicationMessage.Topic] = applicationMessage; - _logger.LogInformation("Client '{0}' updated retained message for topic '{1}'.", clientId, applicationMessage.Topic); + if (!_retainedMessages.ContainsKey(applicationMessage.Topic)) + { + _retainedMessages[applicationMessage.Topic] = applicationMessage; + saveIsRequired = true; + } + else + { + var existingMessage = _retainedMessages[applicationMessage.Topic]; + if (existingMessage.QualityOfServiceLevel != applicationMessage.QualityOfServiceLevel || !existingMessage.Payload.SequenceEqual(applicationMessage.Payload ?? new byte[0])) + { + _retainedMessages[applicationMessage.Topic] = applicationMessage; + saveIsRequired = true; + } + } + + _logger.LogInformation("Client '{0}' set retained message for topic '{1}'.", clientId, applicationMessage.Topic); } - allRetainedMessages = new List(_retainedMessages.Values); - } + if (!saveIsRequired) + { + _logger.LogTrace("Skipped saving retained messages because no changes were detected."); + } - try - { - if (_options.Storage != null) + if (saveIsRequired && _options.Storage != null) { - await _options.Storage.SaveRetainedMessagesAsync(allRetainedMessages); + await _options.Storage.SaveRetainedMessagesAsync(_retainedMessages.Values.ToList()); } } catch (Exception exception) { - _logger.LogError(new EventId(), exception, "Unhandled exception while saving retained messages."); + _logger.LogError(new EventId(), exception, "Unhandled exception while handling retained messages."); + } + finally + { + _gate.Release(); } } - public List GetMessages(MqttSubscribePacket subscribePacket) + public async Task> GetSubscribedMessagesAsync(MqttSubscribePacket subscribePacket) { var retainedMessages = new List(); - lock (_retainedMessages) + + await _gate.WaitAsync(); + try { foreach (var retainedMessage in _retainedMessages.Values) { @@ -103,6 +130,10 @@ namespace MQTTnet.Core.Server } } } + finally + { + _gate.Release(); + } return retainedMessages; } diff --git a/MQTTnet.Core/Server/MqttClientSession.cs b/MQTTnet.Core/Server/MqttClientSession.cs index 7ec312a..9b14e58 100644 --- a/MQTTnet.Core/Server/MqttClientSession.cs +++ b/MQTTnet.Core/Server/MqttClientSession.cs @@ -191,8 +191,9 @@ namespace MQTTnet.Core.Server private async Task HandleIncomingSubscribePacketAsync(IMqttCommunicationAdapter adapter, MqttSubscribePacket subscribePacket, CancellationToken cancellationToken) { var subscribeResult = _subscriptionsManager.Subscribe(subscribePacket, ClientId); + await adapter.SendPacketsAsync(_options.DefaultCommunicationTimeout, cancellationToken, subscribeResult.ResponsePacket); - EnqueueRetainedMessages(subscribePacket); + await EnqueueSubscribedRetainedMessagesAsync(subscribePacket); if (subscribeResult.CloseConnection) { @@ -201,9 +202,9 @@ namespace MQTTnet.Core.Server } } - private void EnqueueRetainedMessages(MqttSubscribePacket subscribePacket) + private async Task EnqueueSubscribedRetainedMessagesAsync(MqttSubscribePacket subscribePacket) { - var retainedMessages = _clientRetainedMessageManager.GetMessages(subscribePacket); + var retainedMessages = await _clientRetainedMessageManager.GetSubscribedMessagesAsync(subscribePacket); foreach (var publishPacket in retainedMessages) { EnqueuePublishPacket(publishPacket.ToPublishPacket()); diff --git a/Tests/MQTTnet.Core.Tests/MqttServerTests.cs b/Tests/MQTTnet.Core.Tests/MqttServerTests.cs index 86ff600..b709835 100644 --- a/Tests/MQTTnet.Core.Tests/MqttServerTests.cs +++ b/Tests/MQTTnet.Core.Tests/MqttServerTests.cs @@ -232,7 +232,7 @@ namespace MQTTnet.Core.Tests }; //make shure the retainedMessageManagerreceived the package - while (!retainMessagemanager.GetMessages(subscribe).Any()) + while (!(await retainMessagemanager.GetSubscribedMessagesAsync(subscribe)).Any()) { await Task.Delay(TimeSpan.FromMilliseconds(10)); } diff --git a/Tests/MQTTnet.TestApp.UniversalWindows/JsonServerStorage.cs b/Tests/MQTTnet.TestApp.UniversalWindows/JsonServerStorage.cs new file mode 100644 index 0000000..9baafec --- /dev/null +++ b/Tests/MQTTnet.TestApp.UniversalWindows/JsonServerStorage.cs @@ -0,0 +1,50 @@ +using System.Collections.Generic; +using System.IO; +using System.Threading.Tasks; +using MQTTnet.Core; +using MQTTnet.Core.Server; +using Newtonsoft.Json; + +namespace MQTTnet.TestApp.UniversalWindows +{ + public class JsonServerStorage : IMqttServerStorage + { + private readonly string _filename = Path.Combine(Windows.Storage.ApplicationData.Current.LocalFolder.Path, "Retained.json"); + + public async Task SaveRetainedMessagesAsync(IList messages) + { + await Task.CompletedTask; + + var json = JsonConvert.SerializeObject(messages); + File.WriteAllText(_filename, json); + } + + public async Task> LoadRetainedMessagesAsync() + { + await Task.CompletedTask; + + if (!File.Exists(_filename)) + { + return new List(); + } + + try + { + var json = File.ReadAllText(_filename); + return JsonConvert.DeserializeObject>(json); + } + catch + { + return new List(); + } + } + + public void Clear() + { + if (File.Exists(_filename)) + { + File.Delete(_filename); + } + } + } +} diff --git a/Tests/MQTTnet.TestApp.UniversalWindows/MQTTnet.TestApp.UniversalWindows.csproj b/Tests/MQTTnet.TestApp.UniversalWindows/MQTTnet.TestApp.UniversalWindows.csproj index 8fda5d3..ce511d6 100644 --- a/Tests/MQTTnet.TestApp.UniversalWindows/MQTTnet.TestApp.UniversalWindows.csproj +++ b/Tests/MQTTnet.TestApp.UniversalWindows/MQTTnet.TestApp.UniversalWindows.csproj @@ -94,6 +94,7 @@ App.xaml + MainPage.xaml diff --git a/Tests/MQTTnet.TestApp.UniversalWindows/MainPage.xaml b/Tests/MQTTnet.TestApp.UniversalWindows/MainPage.xaml index 64d1d67..4fdbd2e 100644 --- a/Tests/MQTTnet.TestApp.UniversalWindows/MainPage.xaml +++ b/Tests/MQTTnet.TestApp.UniversalWindows/MainPage.xaml @@ -27,7 +27,7 @@ Clean session: - + TCP WS @@ -77,9 +77,19 @@ 2 (Exactly once) + Received messages: + + + + + + + + - + + @@ -88,6 +98,9 @@ Port: + Persist retained messages in JSON format + Clear previously retained messages on startup + diff --git a/Tests/MQTTnet.TestApp.UniversalWindows/MainPage.xaml.cs b/Tests/MQTTnet.TestApp.UniversalWindows/MainPage.xaml.cs index ab46740..bcea348 100644 --- a/Tests/MQTTnet.TestApp.UniversalWindows/MainPage.xaml.cs +++ b/Tests/MQTTnet.TestApp.UniversalWindows/MainPage.xaml.cs @@ -88,10 +88,13 @@ namespace MQTTnet.TestApp.UniversalWindows if (_mqttClient != null) { await _mqttClient.DisconnectAsync(); + _mqttClient.ApplicationMessageReceived -= OnApplicationMessageReceived; } var factory = new MqttFactory(); _mqttClient = factory.CreateMqttClient(); + _mqttClient.ApplicationMessageReceived += OnApplicationMessageReceived; + await _mqttClient.ConnectAsync(options); } catch (Exception exception) @@ -100,6 +103,17 @@ namespace MQTTnet.TestApp.UniversalWindows } } + private async void OnApplicationMessageReceived(object sender, MqttApplicationMessageReceivedEventArgs eventArgs) + { + var item = $"Timestamp: {DateTime.Now:O} | Topic: {eventArgs.ApplicationMessage.Topic} | Payload: {Encoding.UTF8.GetString(eventArgs.ApplicationMessage.Payload)} | QoS: {eventArgs.ApplicationMessage.QualityOfServiceLevel}"; + + await Dispatcher.RunAsync(CoreDispatcherPriority.Normal, () => + { + ReceivedMessages.Items.Add(item); + }); + + } + private async void Publish(object sender, RoutedEventArgs e) { if (_mqttClient == null) @@ -332,9 +346,21 @@ namespace MQTTnet.TestApp.UniversalWindows return; } + JsonServerStorage storage = null; + if (ServerPersistRetainedMessages.IsChecked == true) + { + storage = new JsonServerStorage(); + + if (ServerClearRetainedMessages.IsChecked == true) + { + storage.Clear(); + } + } + _mqttServer = new MqttFactory().CreateMqttServer(o => { o.DefaultEndpointOptions.Port = int.Parse(ServerPort.Text); + o.Storage = storage; }); await _mqttServer.StartAsync(); @@ -350,5 +376,10 @@ namespace MQTTnet.TestApp.UniversalWindows await _mqttServer.StopAsync(); _mqttServer = null; } + + private void ClearReceivedMessages(object sender, RoutedEventArgs e) + { + ReceivedMessages.Items.Clear(); + } } }