diff --git a/Build/MQTTnet.nuspec b/Build/MQTTnet.nuspec index 07ee736..564e47e 100644 --- a/Build/MQTTnet.nuspec +++ b/Build/MQTTnet.nuspec @@ -13,6 +13,7 @@ * [Core] Added support for MQTTv5 packages. * [Client] Added new MQTTv5 features to options builder. * [Client] Added uniform API across all supported MQTT versions (BREAKING CHANGE!) +* [Client] The client will now avoid sending an ACK if an exception has been thrown in message handler (thanks to @ramonsmits). * [Server] Added support for MQTTv5 clients. The server will still return _success_ for all cases at the moment even if more granular codes are available. * [Note] Due to MQTTv5 a lot of new classes were introduced. This required adding new namespaces as well. Most classes are backward compatible but new namespaces must be added. diff --git a/MQTTnet.noUWP.sln b/MQTTnet.noUWP.sln index 8d9130e..f68cbb5 100644 --- a/MQTTnet.noUWP.sln +++ b/MQTTnet.noUWP.sln @@ -3,8 +3,6 @@ Microsoft Visual Studio Solution File, Format Version 12.00 # Visual Studio 15 VisualStudioVersion = 15.0.27004.2010 MinimumVisualStudioVersion = 10.0.40219.1 -Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "MQTTnet.Core.Tests", "Tests\MQTTnet.Core.Tests\MQTTnet.Core.Tests.csproj", "{A7FF0C91-25DE-4BA6-B39E-F54E8DADF1CC}" -EndProject Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Tests", "Tests", "{9248C2E1-B9D6-40BF-81EC-86004D7765B4}" EndProject Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Source", "Source", "{32A630A7-2598-41D7-B625-204CD906F5FB}" @@ -45,6 +43,8 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "MQTTnet.Extensions.ManagedC EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "MQTTnet.AspNetCore.Tests", "Tests\MQTTnet.AspNetCore.Tests\MQTTnet.AspNetCore.Tests.csproj", "{61B62223-F5D0-48E4-BBD6-2CBA9353CB5E}" EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "MQTTnet.Tests", "Tests\MQTTnet.Core.Tests\MQTTnet.Tests.csproj", "{9C7106CA-96B8-4ABE-B3B4-9357AB8ACB41}" +EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution Debug|Any CPU = Debug|Any CPU @@ -57,22 +57,6 @@ Global Release|x86 = Release|x86 EndGlobalSection GlobalSection(ProjectConfigurationPlatforms) = postSolution - {A7FF0C91-25DE-4BA6-B39E-F54E8DADF1CC}.Debug|Any CPU.ActiveCfg = Debug|Any CPU - {A7FF0C91-25DE-4BA6-B39E-F54E8DADF1CC}.Debug|Any CPU.Build.0 = Debug|Any CPU - {A7FF0C91-25DE-4BA6-B39E-F54E8DADF1CC}.Debug|ARM.ActiveCfg = Debug|Any CPU - {A7FF0C91-25DE-4BA6-B39E-F54E8DADF1CC}.Debug|ARM.Build.0 = Debug|Any CPU - {A7FF0C91-25DE-4BA6-B39E-F54E8DADF1CC}.Debug|x64.ActiveCfg = Debug|Any CPU - {A7FF0C91-25DE-4BA6-B39E-F54E8DADF1CC}.Debug|x64.Build.0 = Debug|Any CPU - {A7FF0C91-25DE-4BA6-B39E-F54E8DADF1CC}.Debug|x86.ActiveCfg = Debug|Any CPU - {A7FF0C91-25DE-4BA6-B39E-F54E8DADF1CC}.Debug|x86.Build.0 = Debug|Any CPU - {A7FF0C91-25DE-4BA6-B39E-F54E8DADF1CC}.Release|Any CPU.ActiveCfg = Release|Any CPU - {A7FF0C91-25DE-4BA6-B39E-F54E8DADF1CC}.Release|Any CPU.Build.0 = Release|Any CPU - {A7FF0C91-25DE-4BA6-B39E-F54E8DADF1CC}.Release|ARM.ActiveCfg = Release|Any CPU - {A7FF0C91-25DE-4BA6-B39E-F54E8DADF1CC}.Release|ARM.Build.0 = Release|Any CPU - {A7FF0C91-25DE-4BA6-B39E-F54E8DADF1CC}.Release|x64.ActiveCfg = Release|Any CPU - {A7FF0C91-25DE-4BA6-B39E-F54E8DADF1CC}.Release|x64.Build.0 = Release|Any CPU - {A7FF0C91-25DE-4BA6-B39E-F54E8DADF1CC}.Release|x86.ActiveCfg = Release|Any CPU - {A7FF0C91-25DE-4BA6-B39E-F54E8DADF1CC}.Release|x86.Build.0 = Release|Any CPU {3587E506-55A2-4EB3-99C7-DC01E42D25D2}.Debug|Any CPU.ActiveCfg = Debug|Any CPU {3587E506-55A2-4EB3-99C7-DC01E42D25D2}.Debug|Any CPU.Build.0 = Debug|Any CPU {3587E506-55A2-4EB3-99C7-DC01E42D25D2}.Debug|ARM.ActiveCfg = Debug|Any CPU @@ -201,12 +185,27 @@ Global {61B62223-F5D0-48E4-BBD6-2CBA9353CB5E}.Release|x64.Build.0 = Release|Any CPU {61B62223-F5D0-48E4-BBD6-2CBA9353CB5E}.Release|x86.ActiveCfg = Release|Any CPU {61B62223-F5D0-48E4-BBD6-2CBA9353CB5E}.Release|x86.Build.0 = Release|Any CPU + {9C7106CA-96B8-4ABE-B3B4-9357AB8ACB41}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {9C7106CA-96B8-4ABE-B3B4-9357AB8ACB41}.Debug|Any CPU.Build.0 = Debug|Any CPU + {9C7106CA-96B8-4ABE-B3B4-9357AB8ACB41}.Debug|ARM.ActiveCfg = Debug|Any CPU + {9C7106CA-96B8-4ABE-B3B4-9357AB8ACB41}.Debug|ARM.Build.0 = Debug|Any CPU + {9C7106CA-96B8-4ABE-B3B4-9357AB8ACB41}.Debug|x64.ActiveCfg = Debug|Any CPU + {9C7106CA-96B8-4ABE-B3B4-9357AB8ACB41}.Debug|x64.Build.0 = Debug|Any CPU + {9C7106CA-96B8-4ABE-B3B4-9357AB8ACB41}.Debug|x86.ActiveCfg = Debug|Any CPU + {9C7106CA-96B8-4ABE-B3B4-9357AB8ACB41}.Debug|x86.Build.0 = Debug|Any CPU + {9C7106CA-96B8-4ABE-B3B4-9357AB8ACB41}.Release|Any CPU.ActiveCfg = Release|Any CPU + {9C7106CA-96B8-4ABE-B3B4-9357AB8ACB41}.Release|Any CPU.Build.0 = Release|Any CPU + {9C7106CA-96B8-4ABE-B3B4-9357AB8ACB41}.Release|ARM.ActiveCfg = Release|Any CPU + {9C7106CA-96B8-4ABE-B3B4-9357AB8ACB41}.Release|ARM.Build.0 = Release|Any CPU + {9C7106CA-96B8-4ABE-B3B4-9357AB8ACB41}.Release|x64.ActiveCfg = Release|Any CPU + {9C7106CA-96B8-4ABE-B3B4-9357AB8ACB41}.Release|x64.Build.0 = Release|Any CPU + {9C7106CA-96B8-4ABE-B3B4-9357AB8ACB41}.Release|x86.ActiveCfg = Release|Any CPU + {9C7106CA-96B8-4ABE-B3B4-9357AB8ACB41}.Release|x86.Build.0 = Release|Any CPU EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE EndGlobalSection GlobalSection(NestedProjects) = preSolution - {A7FF0C91-25DE-4BA6-B39E-F54E8DADF1CC} = {9248C2E1-B9D6-40BF-81EC-86004D7765B4} {3587E506-55A2-4EB3-99C7-DC01E42D25D2} = {32A630A7-2598-41D7-B625-204CD906F5FB} {3D283AAD-AAA8-4339-8394-52F80B6304DB} = {9248C2E1-B9D6-40BF-81EC-86004D7765B4} {C6FF8AEA-0855-41EC-A1F3-AC262225BAB9} = {9248C2E1-B9D6-40BF-81EC-86004D7765B4} @@ -215,6 +214,7 @@ Global {998D04DD-7CB0-45F5-A393-E2495C16399E} = {9248C2E1-B9D6-40BF-81EC-86004D7765B4} {C400533A-8EBA-4F0B-BF4D-295C3708604B} = {12816BCC-AF9E-44A9-9AE5-C246AF2A0587} {61B62223-F5D0-48E4-BBD6-2CBA9353CB5E} = {9248C2E1-B9D6-40BF-81EC-86004D7765B4} + {9C7106CA-96B8-4ABE-B3B4-9357AB8ACB41} = {9248C2E1-B9D6-40BF-81EC-86004D7765B4} EndGlobalSection GlobalSection(ExtensibilityGlobals) = postSolution SolutionGuid = {07536672-5CBC-4BE3-ACE0-708A431A7894} diff --git a/Source/MQTTnet.Extensions.ManagedClient/ManagedMqttClient.cs b/Source/MQTTnet.Extensions.ManagedClient/ManagedMqttClient.cs index e1feb8d..6fcbee9 100644 --- a/Source/MQTTnet.Extensions.ManagedClient/ManagedMqttClient.cs +++ b/Source/MQTTnet.Extensions.ManagedClient/ManagedMqttClient.cs @@ -5,6 +5,7 @@ using System.Threading; using System.Threading.Tasks; using MQTTnet.Client; using MQTTnet.Client.Publishing; +using MQTTnet.Client.Receiving; using MQTTnet.Diagnostics; using MQTTnet.Exceptions; using MQTTnet.Internal; @@ -52,6 +53,12 @@ namespace MQTTnet.Extensions.ManagedClient public event EventHandler Connected; public event EventHandler Disconnected; + public IMqttApplicationMessageHandler ReceivedApplicationMessageHandler + { + get => _mqttClient.ReceivedApplicationMessageHandler; + set => _mqttClient.ReceivedApplicationMessageHandler = value; + } + public event EventHandler ApplicationMessageReceived; public event EventHandler ApplicationMessageProcessed; public event EventHandler ApplicationMessageSkipped; @@ -372,12 +379,12 @@ namespace MQTTnet.Extensions.ManagedClient { if (unsubscriptions.Any()) { - await _mqttClient.UnsubscribeAsync(unsubscriptions).ConfigureAwait(false); + await _mqttClient.UnsubscribeAsync(unsubscriptions.ToArray()).ConfigureAwait(false); } if (subscriptions.Any()) { - await _mqttClient.SubscribeAsync(subscriptions).ConfigureAwait(false); + await _mqttClient.SubscribeAsync(subscriptions.ToArray()).ConfigureAwait(false); } } catch (Exception exception) diff --git a/Source/MQTTnet/Adapter/MqttChannelAdapter.cs b/Source/MQTTnet/Adapter/MqttChannelAdapter.cs index 44ef5be..ded2a1f 100644 --- a/Source/MQTTnet/Adapter/MqttChannelAdapter.cs +++ b/Source/MQTTnet/Adapter/MqttChannelAdapter.cs @@ -174,12 +174,18 @@ namespace MQTTnet.Adapter private async Task ReceiveAsync(CancellationToken cancellationToken) { - var fixedHeader = await _packetReader.ReadFixedHeaderAsync(_fixedHeaderBuffer, cancellationToken).ConfigureAwait(false); + var readFixedHeaderResult = await _packetReader.ReadFixedHeaderAsync(_fixedHeaderBuffer, cancellationToken).ConfigureAwait(false); try { + if (readFixedHeaderResult.ConnectionClosed) + { + return null; + } + ReadingPacketStarted?.Invoke(this, EventArgs.Empty); + var fixedHeader = readFixedHeaderResult.FixedHeader; if (fixedHeader.RemainingLength == 0) { return new ReceivedMqttPacket(fixedHeader.Flags, null, 2); @@ -205,9 +211,16 @@ namespace MQTTnet.Adapter var readBytes = _channel.ReadAsync(body, bodyOffset, chunkSize, cancellationToken).ConfigureAwait(false).GetAwaiter().GetResult(); #endif - cancellationToken.ThrowIfCancellationRequested(); - ExceptionHelper.ThrowIfGracefulSocketClose(readBytes); + if (cancellationToken.IsCancellationRequested) + { + return null; + } + if (readBytes == 0) + { + return null; + } + bodyOffset += readBytes; } while (bodyOffset < body.Length); diff --git a/Source/MQTTnet/Client/IMqttClient.cs b/Source/MQTTnet/Client/IMqttClient.cs index 4e9cbc1..3c9ded3 100644 --- a/Source/MQTTnet/Client/IMqttClient.cs +++ b/Source/MQTTnet/Client/IMqttClient.cs @@ -1,5 +1,4 @@ using System; -using System.Collections.Generic; using System.Threading.Tasks; using MQTTnet.Client.Connecting; using MQTTnet.Client.Disconnecting; @@ -20,7 +19,7 @@ namespace MQTTnet.Client Task ConnectAsync(IMqttClientOptions options); Task DisconnectAsync(MqttClientDisconnectOptions options); - Task SubscribeAsync(IEnumerable topicFilters); - Task UnsubscribeAsync(IEnumerable topics); + Task SubscribeAsync(MqttClientSubscribeOptions options); + Task UnsubscribeAsync(MqttClientUnsubscribeOptions options); } } \ No newline at end of file diff --git a/Source/MQTTnet/Client/MqttClient.cs b/Source/MQTTnet/Client/MqttClient.cs index be29f5d..90c5fa0 100644 --- a/Source/MQTTnet/Client/MqttClient.cs +++ b/Source/MQTTnet/Client/MqttClient.cs @@ -1,20 +1,19 @@ using System; -using System.Collections.Generic; using System.Diagnostics; -using System.Linq; using System.Threading; using System.Threading.Tasks; using MQTTnet.Adapter; using MQTTnet.Client.Connecting; using MQTTnet.Client.Disconnecting; using MQTTnet.Client.Options; -using MQTTnet.Client.PacketDispatcher; using MQTTnet.Client.Publishing; +using MQTTnet.Client.Receiving; using MQTTnet.Client.Subscribing; using MQTTnet.Client.Unsubscribing; using MQTTnet.Diagnostics; using MQTTnet.Exceptions; using MQTTnet.Formatter; +using MQTTnet.PacketDispatcher; using MQTTnet.Packets; using MQTTnet.Protocol; @@ -23,16 +22,18 @@ namespace MQTTnet.Client public class MqttClient : IMqttClient { private readonly MqttPacketIdentifierProvider _packetIdentifierProvider = new MqttPacketIdentifierProvider(); + private readonly MqttPacketDispatcher _packetDispatcher = new MqttPacketDispatcher(); private readonly Stopwatch _sendTracker = new Stopwatch(); private readonly object _disconnectLock = new object(); - private readonly MqttPacketDispatcher _packetDispatcher = new MqttPacketDispatcher(); private readonly IMqttClientAdapterFactory _adapterFactory; private readonly IMqttNetChildLogger _logger; private CancellationTokenSource _cancellationTokenSource; - internal Task _packetReceiverTask; - internal Task _keepAliveMessageSenderTask; + private Task _packetReceiverTask; + private Task _keepAlivePacketsSenderTask; + private Task _backgroundWorkerTask; + private IMqttChannelAdapter _adapter; private bool _cleanDisconnectInitiated; private long _disconnectGate; @@ -47,6 +48,8 @@ namespace MQTTnet.Client public event EventHandler Connected; public event EventHandler Disconnected; + + public IMqttApplicationMessageHandler ReceivedApplicationMessageHandler { get; set; } public event EventHandler ApplicationMessageReceived; public bool IsConnected { get; private set; } @@ -66,28 +69,31 @@ namespace MQTTnet.Client _packetIdentifierProvider.Reset(); _packetDispatcher.Reset(); - var cancellationTokenSource = new CancellationTokenSource(); - _cancellationTokenSource = cancellationTokenSource; + _cancellationTokenSource = new CancellationTokenSource(); + var cancellationToken = _cancellationTokenSource.Token; _disconnectGate = 0; - _adapter = _adapterFactory.CreateClientAdapter(options, _logger); + var adapter = _adapterFactory.CreateClientAdapter(options, _logger); + _adapter = adapter; _logger.Verbose($"Trying to connect with server ({Options.ChannelOptions})."); - await _adapter.ConnectAsync(Options.CommunicationTimeout, cancellationTokenSource.Token).ConfigureAwait(false); + await _adapter.ConnectAsync(Options.CommunicationTimeout, cancellationToken).ConfigureAwait(false); _logger.Verbose("Connection with server established."); - StartReceivingPackets(cancellationTokenSource.Token); + _packetReceiverTask = Task.Run(() => TryReceivePacketsAsync(cancellationToken), cancellationToken); - var connectResult = await AuthenticateAsync(options.WillMessage, cancellationTokenSource.Token).ConfigureAwait(false); + var connectResult = await AuthenticateAsync(adapter, options.WillMessage, cancellationToken).ConfigureAwait(false); _logger.Verbose("MQTT connection with server established."); _sendTracker.Restart(); if (Options.KeepAlivePeriod != TimeSpan.Zero) { - StartSendingKeepAliveMessages(cancellationTokenSource.Token); + _keepAlivePacketsSenderTask = Task.Run(() => TrySendKeepAliveMessagesAsync(cancellationToken), cancellationToken); } + _backgroundWorkerTask = Task.Run(() => TryProcessReceivedPacketsAsync(cancellationToken), cancellationToken); + IsConnected = true; Connected?.Invoke(this, new MqttClientConnectedEventArgs(connectResult)); @@ -129,35 +135,27 @@ namespace MQTTnet.Client } } - public async Task SubscribeAsync(IEnumerable topicFilters) + public async Task SubscribeAsync(MqttClientSubscribeOptions options) { - if (topicFilters == null) throw new ArgumentNullException(nameof(topicFilters)); + if (options == null) throw new ArgumentNullException(nameof(options)); ThrowIfNotConnected(); - var subscribePacket = new MqttSubscribePacket - { - PacketIdentifier = _packetIdentifierProvider.GetNewPacketIdentifier() - }; - - subscribePacket.TopicFilters.AddRange(topicFilters); + var subscribePacket = _adapter.PacketFormatterAdapter.DataConverter.CreateSubscribePacket(options); + subscribePacket.PacketIdentifier = _packetIdentifierProvider.GetNextPacketIdentifier(); var subAckPacket = await SendAndReceiveAsync(subscribePacket, _cancellationTokenSource.Token).ConfigureAwait(false); return _adapter.PacketFormatterAdapter.DataConverter.CreateClientSubscribeResult(subscribePacket, subAckPacket); } - public async Task UnsubscribeAsync(IEnumerable topicFilters) + public async Task UnsubscribeAsync(MqttClientUnsubscribeOptions options) { - if (topicFilters == null) throw new ArgumentNullException(nameof(topicFilters)); + if (options == null) throw new ArgumentNullException(nameof(options)); ThrowIfNotConnected(); - var unsubscribePacket = new MqttUnsubscribePacket - { - PacketIdentifier = _packetIdentifierProvider.GetNewPacketIdentifier() - }; - - unsubscribePacket.TopicFilters.AddRange(topicFilters); + var unsubscribePacket = _adapter.PacketFormatterAdapter.DataConverter.CreateUnsubscribePacket(options); + unsubscribePacket.PacketIdentifier = _packetIdentifierProvider.GetNextPacketIdentifier(); var unsubAckPacket = await SendAndReceiveAsync(unsubscribePacket, _cancellationTokenSource.Token).ConfigureAwait(false); return _adapter.PacketFormatterAdapter.DataConverter.CreateClientUnsubscribeResult(unsubscribePacket, unsubAckPacket); @@ -181,7 +179,7 @@ namespace MQTTnet.Client } case MqttQualityOfServiceLevel.AtLeastOnce: { - publishPacket.PacketIdentifier = _packetIdentifierProvider.GetNewPacketIdentifier(); + publishPacket.PacketIdentifier = _packetIdentifierProvider.GetNextPacketIdentifier(); var response = await SendAndReceiveAsync(publishPacket, _cancellationTokenSource.Token).ConfigureAwait(false); var result = new MqttClientPublishResult(); @@ -213,15 +211,14 @@ namespace MQTTnet.Client _adapter = null; } - private async Task AuthenticateAsync(MqttApplicationMessage willApplicationMessage, CancellationToken cancellationToken) + private async Task AuthenticateAsync(IMqttChannelAdapter channelAdapter, MqttApplicationMessage willApplicationMessage, CancellationToken cancellationToken) { - var connectPacket = _adapter.PacketFormatterAdapter.DataConverter.CreateConnectPacket( + var connectPacket = channelAdapter.PacketFormatterAdapter.DataConverter.CreateConnectPacket( willApplicationMessage, Options); var connAckPacket = await SendAndReceiveAsync(connectPacket, cancellationToken).ConfigureAwait(false); - - var result = _adapter.PacketFormatterAdapter.DataConverter.CreateClientConnectResult(connAckPacket); + var result = channelAdapter.PacketFormatterAdapter.DataConverter.CreateClientConnectResult(connAckPacket); if (result.ResultCode != MqttClientConnectResultCode.Success) { @@ -257,7 +254,8 @@ namespace MQTTnet.Client } await WaitForTaskAsync(_packetReceiverTask, sender).ConfigureAwait(false); - await WaitForTaskAsync(_keepAliveMessageSenderTask, sender).ConfigureAwait(false); + await WaitForTaskAsync(_keepAlivePacketsSenderTask, sender).ConfigureAwait(false); + await WaitForTaskAsync(_backgroundWorkerTask, sender).ConfigureAwait(false); _logger.Verbose("Disconnected from adapter."); } @@ -316,26 +314,22 @@ namespace MQTTnet.Client identifier = packetWithIdentifier.PacketIdentifier.Value; } - var packetAwaiter = _packetDispatcher.AddPacketAwaiter(identifier); - try + using (var packetAwaiter = _packetDispatcher.AddPacketAwaiter(identifier)) { - await _adapter.SendPacketAsync(requestPacket, cancellationToken).ConfigureAwait(false); - return await packetAwaiter.WaitOneAsync(Options.CommunicationTimeout); - - //return (TResponsePacket)await Internal.TaskExtensions.TimeoutAfterAsync(ct => packetAwaiter.Task, Options.CommunicationTimeout, cancellationToken).ConfigureAwait(false); - } - catch (MqttCommunicationTimedOutException) - { - _logger.Warning(null, "Timeout while waiting for packet of type '{0}'.", typeof(TResponsePacket).Namespace); - throw; - } - finally - { - _packetDispatcher.RemovePacketAwaiter(identifier); + try + { + await _adapter.SendPacketAsync(requestPacket, cancellationToken).ConfigureAwait(false); + return await packetAwaiter.WaitOneAsync(Options.CommunicationTimeout).ConfigureAwait(false); + } + catch (MqttCommunicationTimedOutException) + { + _logger.Warning(null, "Timeout while waiting for packet of type '{0}'.", typeof(TResponsePacket).Namespace); + throw; + } } } - private async Task SendKeepAliveMessagesAsync(CancellationToken cancellationToken) + private async Task TrySendKeepAliveMessagesAsync(CancellationToken cancellationToken) { try { @@ -380,7 +374,7 @@ namespace MQTTnet.Client if (!DisconnectIsPending()) { - await DisconnectInternalAsync(_keepAliveMessageSenderTask, exception).ConfigureAwait(false); + await DisconnectInternalAsync(_keepAlivePacketsSenderTask, exception).ConfigureAwait(false); } } finally @@ -389,7 +383,7 @@ namespace MQTTnet.Client } } - private async Task ReceivePacketsAsync(CancellationToken cancellationToken) + private async Task TryReceivePacketsAsync(CancellationToken cancellationToken) { try { @@ -399,10 +393,17 @@ namespace MQTTnet.Client { var packet = await _adapter.ReceivePacketAsync(TimeSpan.Zero, cancellationToken).ConfigureAwait(false); - if (packet != null && !cancellationToken.IsCancellationRequested) + if (cancellationToken.IsCancellationRequested) + { + return; + } + + if (packet == null && !DisconnectIsPending()) { - await ProcessReceivedPacketAsync(packet, cancellationToken).ConfigureAwait(false); + await DisconnectInternalAsync(_packetReceiverTask, null).ConfigureAwait(false); } + + _packetDispatcher.Dispatch(packet); } } catch (Exception exception) @@ -437,79 +438,146 @@ namespace MQTTnet.Client } } - private Task ProcessReceivedPacketAsync(MqttBasePacket packet, CancellationToken cancellationToken) + private async Task TryProcessReceivedPacketAsync(MqttBasePacket packet, CancellationToken cancellationToken) { - if (packet is MqttPublishPacket publishPacket) + try { - return ProcessReceivedPublishPacketAsync(publishPacket, cancellationToken); + if (packet is MqttPublishPacket publishPacket) + { + await TryProcessReceivedPublishPacketAsync(publishPacket, cancellationToken).ConfigureAwait(false); + } + else if (packet is MqttPingReqPacket) + { + await SendAsync(new MqttPingRespPacket(), cancellationToken).ConfigureAwait(false); + } + else if (packet is MqttDisconnectPacket) + { + await DisconnectAsync(null).ConfigureAwait(false); + } + else + { + throw new MqttProtocolViolationException($"Received packet '{packet}' at an unexpected time."); + } } - - if (packet is MqttPingReqPacket) + catch (Exception exception) { - return SendAsync(new MqttPingRespPacket(), cancellationToken); - } + if (_cleanDisconnectInitiated) + { + return; + } - if (packet is MqttDisconnectPacket) - { - return DisconnectAsync(null); - } + if (exception is OperationCanceledException) + { + } + else if (exception is MqttCommunicationException) + { + _logger.Warning(exception, "MQTT communication exception while receiving packets."); + } + else + { + _logger.Error(exception, "Unhandled exception while receiving packets."); + } - if (packet is MqttPubRelPacket pubRelPacket) - { - return ProcessReceivedPubRelPacket(pubRelPacket, cancellationToken); - } + _packetDispatcher.Dispatch(exception); - _packetDispatcher.Dispatch(packet); - return Task.FromResult(0); + if (!DisconnectIsPending()) + { + await DisconnectInternalAsync(_packetReceiverTask, exception).ConfigureAwait(false); + } + } } - private Task ProcessReceivedPublishPacketAsync(MqttPublishPacket publishPacket, CancellationToken cancellationToken) + private async Task TryProcessReceivedPublishPacketAsync(MqttPublishPacket publishPacket, CancellationToken cancellationToken) { - if (publishPacket.QualityOfServiceLevel == MqttQualityOfServiceLevel.AtMostOnce) + try { - FireApplicationMessageReceivedEvent(publishPacket); - return Task.FromResult(0); - } + if (publishPacket.QualityOfServiceLevel == MqttQualityOfServiceLevel.AtMostOnce) + { + await HandleReceivedApplicationMessageAsync(publishPacket).ConfigureAwait(false); + } + else if (publishPacket.QualityOfServiceLevel == MqttQualityOfServiceLevel.AtLeastOnce) + { + await HandleReceivedApplicationMessageAsync(publishPacket).ConfigureAwait(false); - if (publishPacket.QualityOfServiceLevel == MqttQualityOfServiceLevel.AtLeastOnce) - { - FireApplicationMessageReceivedEvent(publishPacket); - return SendAsync(new MqttPubAckPacket { PacketIdentifier = publishPacket.PacketIdentifier, ReasonCode = MqttPubAckReasonCode.Success }, cancellationToken); - } + await SendAsync(new MqttPubAckPacket + { + PacketIdentifier = publishPacket.PacketIdentifier, + ReasonCode = MqttPubAckReasonCode.Success + }, cancellationToken).ConfigureAwait(false); + } + else if (publishPacket.QualityOfServiceLevel == MqttQualityOfServiceLevel.ExactlyOnce) + { + var pubRecPacket = new MqttPubRecPacket + { + PacketIdentifier = publishPacket.PacketIdentifier, + ReasonCode = MqttPubRecReasonCode.Success + }; - if (publishPacket.QualityOfServiceLevel == MqttQualityOfServiceLevel.ExactlyOnce) + var pubRelPacket = await SendAndReceiveAsync(pubRecPacket, cancellationToken).ConfigureAwait(false); + + // QoS 2 is implement as method "B" (4.3.3 QoS 2: Exactly once delivery) + await HandleReceivedApplicationMessageAsync(publishPacket).ConfigureAwait(false); + + await SendAsync(new MqttPubCompPacket + { + PacketIdentifier = pubRelPacket.PacketIdentifier, + ReasonCode = MqttPubCompReasonCode.Success + }, cancellationToken).ConfigureAwait(false); + } + else + { + throw new MqttProtocolViolationException("Received a not supported QoS level."); + } + } + catch (Exception exception) { - // QoS 2 is implement as method "B" (4.3.3 QoS 2: Exactly once delivery) - FireApplicationMessageReceivedEvent(publishPacket); - return SendAsync(new MqttPubRecPacket { PacketIdentifier = publishPacket.PacketIdentifier, ReasonCode = MqttPubRecReasonCode.Success }, cancellationToken); + _logger.Error(exception, "Unhandled exception while handling application message."); } - - throw new MqttCommunicationException("Received a not supported QoS level."); } - private Task ProcessReceivedPubRelPacket(MqttPubRelPacket pubRelPacket, CancellationToken cancellationToken) + private async Task TryProcessReceivedPacketsAsync(CancellationToken cancellationToken) { - var response = new MqttPubCompPacket + try { - PacketIdentifier = pubRelPacket.PacketIdentifier, - ReasonCode = MqttPubCompReasonCode.Success - }; + while (!cancellationToken.IsCancellationRequested) + { + var packet = _packetDispatcher.Take(cancellationToken); + + if (cancellationToken.IsCancellationRequested) + { + return; + } - return SendAsync(response, cancellationToken); + await TryProcessReceivedPacketAsync(packet, cancellationToken).ConfigureAwait(false); + } + } + catch (OperationCanceledException) + { + } + catch (Exception exception) + { + _logger.Error(exception, "Error while processing packet."); + } } private async Task PublishExactlyOnceAsync(MqttPublishPacket publishPacket, CancellationToken cancellationToken) { - publishPacket.PacketIdentifier = _packetIdentifierProvider.GetNewPacketIdentifier(); + publishPacket.PacketIdentifier = _packetIdentifierProvider.GetNextPacketIdentifier(); var pubRecPacket = await SendAndReceiveAsync(publishPacket, cancellationToken).ConfigureAwait(false); + + // TODO: Check response code. + var pubRelPacket = new MqttPubRelPacket { - PacketIdentifier = pubRecPacket.PacketIdentifier, + PacketIdentifier = publishPacket.PacketIdentifier, ReasonCode = MqttPubRelReasonCode.Success }; var pubCompPacket = await SendAndReceiveAsync(pubRelPacket, cancellationToken).ConfigureAwait(false); + + // TODO: Check response code. + var result = new MqttClientPublishResult(); if (pubRecPacket.ReasonCode != null) @@ -520,35 +588,38 @@ namespace MQTTnet.Client return result; } - private void StartReceivingPackets(CancellationToken cancellationToken) + ////private void StartReceivingPackets(CancellationToken cancellationToken) + ////{ + //// _packetReceiverTask = Task.Factory.StartNew( + //// () => TryReceivePacketsAsync(cancellationToken), + //// cancellationToken, + //// TaskCreationOptions.LongRunning, + //// TaskScheduler.Default).Unwrap(); + ////} + + ////private void StartSendingKeepAliveMessages(CancellationToken cancellationToken) + ////{ + //// _keepAlivePacketsSenderTask = Task.Factory.StartNew( + //// () => TrySendKeepAliveMessagesAsync(cancellationToken), + //// cancellationToken, + //// TaskCreationOptions.LongRunning, + //// TaskScheduler.Default).Unwrap(); + ////} + + private Task HandleReceivedApplicationMessageAsync(MqttPublishPacket publishPacket) { - _packetReceiverTask = Task.Factory.StartNew( - () => ReceivePacketsAsync(cancellationToken), - cancellationToken, - TaskCreationOptions.LongRunning, - TaskScheduler.Default).Unwrap(); - } + var applicationMessage = _adapter.PacketFormatterAdapter.DataConverter.CreateApplicationMessage(publishPacket); - private void StartSendingKeepAliveMessages(CancellationToken cancellationToken) - { - _keepAliveMessageSenderTask = Task.Factory.StartNew( - () => SendKeepAliveMessagesAsync(cancellationToken), - cancellationToken, - TaskCreationOptions.LongRunning, - TaskScheduler.Default).Unwrap(); - } + ApplicationMessageReceived?.Invoke(this, new MqttApplicationMessageReceivedEventArgs(Options.ClientId, applicationMessage)); - private void FireApplicationMessageReceivedEvent(MqttPublishPacket publishPacket) - { - try + var handler = ReceivedApplicationMessageHandler; + if (handler != null) { - var applicationMessage = _adapter.PacketFormatterAdapter.DataConverter.CreateApplicationMessage(publishPacket); - ApplicationMessageReceived?.Invoke(this, new MqttApplicationMessageReceivedEventArgs(Options.ClientId, applicationMessage)); - } - catch (Exception exception) - { - _logger.Error(exception, "Unhandled exception while handling application message."); + return handler.HandleApplicationMessageAsync( + new MqttApplicationMessageHandlerContext(Options.ClientId, applicationMessage)); } + + return Task.FromResult(0); } private static async Task WaitForTaskAsync(Task task, Task sender) diff --git a/Source/MQTTnet/Client/MqttClientExtensions.cs b/Source/MQTTnet/Client/MqttClientExtensions.cs index b399ef6..ef6f5ab 100644 --- a/Source/MQTTnet/Client/MqttClientExtensions.cs +++ b/Source/MQTTnet/Client/MqttClientExtensions.cs @@ -1,5 +1,6 @@ using System; using System.Threading.Tasks; +using MQTTnet.Client.Receiving; using MQTTnet.Client.Subscribing; using MQTTnet.Client.Unsubscribing; using MQTTnet.Protocol; @@ -8,6 +9,51 @@ namespace MQTTnet.Client { public static class MqttClientExtensions { + public static IMqttClient UseReceivedApplicationMessageHandler(this IMqttClient client, Func handler) + { + if (handler == null) + { + client.ReceivedApplicationMessageHandler = null; + return client; + } + + client.ReceivedApplicationMessageHandler = new MqttApplicationMessageHandlerDelegate(handler); + + return client; + } + + public static IMqttClient UseReceivedApplicationMessageHandler(this IMqttClient client, Action handler) + { + if (handler == null) + { + client.ReceivedApplicationMessageHandler = null; + return client; + } + + client.ReceivedApplicationMessageHandler = new MqttApplicationMessageHandlerDelegate(handler); + + return client; + } + + public static IMqttClient UseReceivedApplicationMessageHandler(this IMqttClient client, IMqttApplicationMessageHandler handler) + { + client.ReceivedApplicationMessageHandler = handler; + + return client; + } + + public static Task ReconnectAsync(this IMqttClient client) + { + if (client == null) throw new ArgumentNullException(nameof(client)); + + if (client.Options == null) + { + throw new InvalidOperationException("_ReconnectAsync_ can be used only if _ConnectAsync_ was called before."); + } + + return client.ConnectAsync(client.Options); + } + public static Task DisconnectAsync(this IMqttClient client) { if (client == null) throw new ArgumentNullException(nameof(client)); @@ -20,7 +66,10 @@ namespace MQTTnet.Client if (client == null) throw new ArgumentNullException(nameof(client)); if (topicFilters == null) throw new ArgumentNullException(nameof(topicFilters)); - return client.SubscribeAsync(topicFilters); + var options = new MqttClientSubscribeOptions(); + options.TopicFilters.AddRange(topicFilters); + + return client.SubscribeAsync(options); } public static Task SubscribeAsync(this IMqttClient client, string topic, MqttQualityOfServiceLevel qualityOfServiceLevel) @@ -44,7 +93,10 @@ namespace MQTTnet.Client if (client == null) throw new ArgumentNullException(nameof(client)); if (topicFilters == null) throw new ArgumentNullException(nameof(topicFilters)); - return client.UnsubscribeAsync(topicFilters); + var options = new MqttClientUnsubscribeOptions(); + options.TopicFilters.AddRange(topicFilters); + + return client.UnsubscribeAsync(options); } } } \ No newline at end of file diff --git a/Source/MQTTnet/Client/MqttPacketIdentifierProvider.cs b/Source/MQTTnet/Client/MqttPacketIdentifierProvider.cs index dc09918..bfee386 100644 --- a/Source/MQTTnet/Client/MqttPacketIdentifierProvider.cs +++ b/Source/MQTTnet/Client/MqttPacketIdentifierProvider.cs @@ -13,7 +13,7 @@ } } - public ushort GetNewPacketIdentifier() + public ushort GetNextPacketIdentifier() { lock (_syncRoot) { diff --git a/Source/MQTTnet/Client/Options/MqttClientOptions.cs b/Source/MQTTnet/Client/Options/MqttClientOptions.cs index a76ecf5..be9dc7b 100644 --- a/Source/MQTTnet/Client/Options/MqttClientOptions.cs +++ b/Source/MQTTnet/Client/Options/MqttClientOptions.cs @@ -7,7 +7,7 @@ namespace MQTTnet.Client.Options { public string ClientId { get; set; } = Guid.NewGuid().ToString("N"); public bool CleanSession { get; set; } = true; - public IMqttClientCredentials Credentials { get; set; } = new MqttClientCredentials(); + public IMqttClientCredentials Credentials { get; set; } public MqttProtocolVersion ProtocolVersion { get; set; } = MqttProtocolVersion.V311; public IMqttClientChannelOptions ChannelOptions { get; set; } diff --git a/Source/MQTTnet/Client/PacketDispatcher/MqttPacketAwaiter.cs b/Source/MQTTnet/Client/PacketDispatcher/MqttPacketAwaiter.cs deleted file mode 100644 index 292f231..0000000 --- a/Source/MQTTnet/Client/PacketDispatcher/MqttPacketAwaiter.cs +++ /dev/null @@ -1,37 +0,0 @@ -using System; -using System.Threading; -using System.Threading.Tasks; -using MQTTnet.Packets; - -namespace MQTTnet.Client.PacketDispatcher -{ - public sealed class MqttPacketAwaiter : IMqttPacketAwaiter where TPacket : MqttBasePacket - { - private readonly TaskCompletionSource _packet = new TaskCompletionSource(); - - public async Task WaitOneAsync(TimeSpan timeout) - { - using (var timeoutToken = new CancellationTokenSource(timeout)) - { - timeoutToken.Token.Register(() => _packet.TrySetCanceled()); - - var packet = await _packet.Task.ConfigureAwait(false); - return (TPacket)packet; - } - } - - public void Complete(MqttBasePacket packet) - { - if (packet == null) throw new ArgumentNullException(nameof(packet)); - - _packet.TrySetResult(packet); - } - - public void Fail(Exception exception) - { - if (exception == null) throw new ArgumentNullException(nameof(exception)); - - _packet.TrySetException(exception); - } - } -} \ No newline at end of file diff --git a/Source/MQTTnet/Client/Receiving/IMqttApplicationMessageHandler.cs b/Source/MQTTnet/Client/Receiving/IMqttApplicationMessageHandler.cs new file mode 100644 index 0000000..fdc5bf6 --- /dev/null +++ b/Source/MQTTnet/Client/Receiving/IMqttApplicationMessageHandler.cs @@ -0,0 +1,9 @@ +using System.Threading.Tasks; + +namespace MQTTnet.Client.Receiving +{ + public interface IMqttApplicationMessageHandler + { + Task HandleApplicationMessageAsync(MqttApplicationMessageHandlerContext context); + } +} diff --git a/Source/MQTTnet/Client/Receiving/MqttApplicationMessageHandlerContext.cs b/Source/MQTTnet/Client/Receiving/MqttApplicationMessageHandlerContext.cs new file mode 100644 index 0000000..87fb959 --- /dev/null +++ b/Source/MQTTnet/Client/Receiving/MqttApplicationMessageHandlerContext.cs @@ -0,0 +1,15 @@ +namespace MQTTnet.Client.Receiving +{ + public class MqttApplicationMessageHandlerContext + { + public MqttApplicationMessageHandlerContext(string senderClientId, MqttApplicationMessage applicationMessage) + { + SenderClientId = senderClientId; + ApplicationMessage = applicationMessage; + } + + public string SenderClientId { get; } + + public MqttApplicationMessage ApplicationMessage { get; } + } +} diff --git a/Source/MQTTnet/Client/Receiving/MqttApplicationMessageHandlerDelegate.cs b/Source/MQTTnet/Client/Receiving/MqttApplicationMessageHandlerDelegate.cs new file mode 100644 index 0000000..cde8944 --- /dev/null +++ b/Source/MQTTnet/Client/Receiving/MqttApplicationMessageHandlerDelegate.cs @@ -0,0 +1,31 @@ +using System; +using System.Threading.Tasks; + +namespace MQTTnet.Client.Receiving +{ + public class MqttApplicationMessageHandlerDelegate : IMqttApplicationMessageHandler + { + private readonly Func _handler; + + public MqttApplicationMessageHandlerDelegate(Action handler) + { + if (handler == null) throw new ArgumentNullException(nameof(handler)); + + _handler = context => + { + handler(context); + return Task.FromResult(0); + }; + } + + public MqttApplicationMessageHandlerDelegate(Func handler) + { + _handler = handler ?? throw new ArgumentNullException(nameof(handler)); + } + + public Task HandleApplicationMessageAsync(MqttApplicationMessageHandlerContext context) + { + return _handler(context); + } + } +} diff --git a/Source/MQTTnet/Client/Subscribing/MqttClientSubscribeOptions.cs b/Source/MQTTnet/Client/Subscribing/MqttClientSubscribeOptions.cs new file mode 100644 index 0000000..8d2adbf --- /dev/null +++ b/Source/MQTTnet/Client/Subscribing/MqttClientSubscribeOptions.cs @@ -0,0 +1,12 @@ +using System.Collections.Generic; +using MQTTnet.Packets; + +namespace MQTTnet.Client.Subscribing +{ + public class MqttClientSubscribeOptions + { + public List TopicFilters { get; set; } = new List(); + + public List UserProperties { get; set; } = new List(); + } +} diff --git a/Source/MQTTnet/Client/Unsubscribing/MqttClientUnsubscribeOptions.cs b/Source/MQTTnet/Client/Unsubscribing/MqttClientUnsubscribeOptions.cs new file mode 100644 index 0000000..73c21ef --- /dev/null +++ b/Source/MQTTnet/Client/Unsubscribing/MqttClientUnsubscribeOptions.cs @@ -0,0 +1,12 @@ +using System.Collections.Generic; +using MQTTnet.Packets; + +namespace MQTTnet.Client.Unsubscribing +{ + public class MqttClientUnsubscribeOptions + { + public List TopicFilters { get; set; } = new List(); + + public List UserProperties { get; set; } = new List(); + } +} diff --git a/Source/MQTTnet/Exceptions/MqttCommunicationClosedGracefullyException.cs b/Source/MQTTnet/Exceptions/MqttCommunicationClosedGracefullyException.cs deleted file mode 100644 index 21a66b9..0000000 --- a/Source/MQTTnet/Exceptions/MqttCommunicationClosedGracefullyException.cs +++ /dev/null @@ -1,6 +0,0 @@ -namespace MQTTnet.Exceptions -{ - public class MqttCommunicationClosedGracefullyException : MqttCommunicationException - { - } -} diff --git a/Source/MQTTnet/Formatter/IMqttDataConverter.cs b/Source/MQTTnet/Formatter/IMqttDataConverter.cs index 44516ee..19c2a28 100644 --- a/Source/MQTTnet/Formatter/IMqttDataConverter.cs +++ b/Source/MQTTnet/Formatter/IMqttDataConverter.cs @@ -19,5 +19,9 @@ namespace MQTTnet.Formatter MqttClientSubscribeResult CreateClientSubscribeResult(MqttSubscribePacket subscribePacket, MqttSubAckPacket subAckPacket); MqttClientUnsubscribeResult CreateClientUnsubscribeResult(MqttUnsubscribePacket unsubscribePacket, MqttUnsubAckPacket unsubAckPacket); + + MqttSubscribePacket CreateSubscribePacket(MqttClientSubscribeOptions options); + + MqttUnsubscribePacket CreateUnsubscribePacket(MqttClientUnsubscribeOptions options); } } diff --git a/Source/MQTTnet/Formatter/MqttPacketReader.cs b/Source/MQTTnet/Formatter/MqttPacketReader.cs index eec643a..b4e7b67 100644 --- a/Source/MQTTnet/Formatter/MqttPacketReader.cs +++ b/Source/MQTTnet/Formatter/MqttPacketReader.cs @@ -18,7 +18,7 @@ namespace MQTTnet.Formatter _channel = channel ?? throw new ArgumentNullException(nameof(channel)); } - public async Task ReadFixedHeaderAsync(byte[] fixedHeaderBuffer, CancellationToken cancellationToken) + public async Task ReadFixedHeaderAsync(byte[] fixedHeaderBuffer, CancellationToken cancellationToken) { // The MQTT fixed header contains 1 byte of flags and at least 1 byte for the remaining data length. // So in all cases at least 2 bytes must be read for a complete MQTT packet. @@ -32,15 +32,25 @@ namespace MQTTnet.Formatter var bytesRead = await _channel.ReadAsync(buffer, totalBytesRead, buffer.Length - totalBytesRead, cancellationToken).ConfigureAwait(false); cancellationToken.ThrowIfCancellationRequested(); - ExceptionHelper.ThrowIfGracefulSocketClose(bytesRead); + if (bytesRead == 0) + { + return new ReadFixedHeaderResult + { + ConnectionClosed = true + }; + } + totalBytesRead += bytesRead; } var hasRemainingLength = buffer[1] != 0; if (!hasRemainingLength) { - return new MqttFixedHeader(buffer[0], 0, totalBytesRead); + return new ReadFixedHeaderResult + { + FixedHeader = new MqttFixedHeader(buffer[0], 0, totalBytesRead) + }; } #if WINDOWS_UWP @@ -54,12 +64,23 @@ namespace MQTTnet.Formatter var bodyLength = ReadBodyLength(buffer[1], cancellationToken); #endif - totalBytesRead += bodyLength; - return new MqttFixedHeader(buffer[0], bodyLength, totalBytesRead); + if (!bodyLength.HasValue) + { + return new ReadFixedHeaderResult + { + ConnectionClosed = true + }; + } + + totalBytesRead += bodyLength.Value; + return new ReadFixedHeaderResult + { + FixedHeader = new MqttFixedHeader(buffer[0], bodyLength.Value, totalBytesRead) + }; } #if !WINDOWS_UWP - private int ReadBodyLength(byte initialEncodedByte, CancellationToken cancellationToken) + private int? ReadBodyLength(byte initialEncodedByte, CancellationToken cancellationToken) { var offset = 0; var multiplier = 128; @@ -74,9 +95,18 @@ namespace MQTTnet.Formatter throw new MqttProtocolViolationException("Remaining length is invalid."); } - cancellationToken.ThrowIfCancellationRequested(); + if (cancellationToken.IsCancellationRequested) + { + return null; + } + + var buffer = ReadByte(cancellationToken); + if (!buffer.HasValue) + { + return null; + } - encodedByte = ReadByte(cancellationToken); + encodedByte = buffer.Value; value += (encodedByte & 127) * multiplier; multiplier *= 128; @@ -85,15 +115,18 @@ namespace MQTTnet.Formatter return value; } - private byte ReadByte(CancellationToken cancellationToken) + private byte? ReadByte(CancellationToken cancellationToken) { var readCount = _channel.ReadAsync(_singleByteBuffer, 0, 1, cancellationToken).ConfigureAwait(false).GetAwaiter().GetResult(); - cancellationToken.ThrowIfCancellationRequested(); + if (cancellationToken.IsCancellationRequested) + { + return null; + } - if (readCount <= 0) + if (readCount == 0) { - ExceptionHelper.ThrowGracefulSocketClose(); + return null; } return _singleByteBuffer[0]; @@ -101,7 +134,7 @@ namespace MQTTnet.Formatter #else - private async Task ReadBodyLengthAsync(byte initialEncodedByte, CancellationToken cancellationToken) + private async Task ReadBodyLengthAsync(byte initialEncodedByte, CancellationToken cancellationToken) { var offset = 0; var multiplier = 128; @@ -116,9 +149,18 @@ namespace MQTTnet.Formatter throw new MqttProtocolViolationException("Remaining length is invalid."); } - cancellationToken.ThrowIfCancellationRequested(); + if (cancellationToken.IsCancellationRequested) + { + return null; + } + + var buffer = await ReadByteAsync(cancellationToken).ConfigureAwait(false); + if (!buffer.HasValue) + { + return null; + } - encodedByte = await ReadByteAsync(cancellationToken).ConfigureAwait(false); + encodedByte = buffer.Value; value += (encodedByte & 127) * multiplier; multiplier *= 128; @@ -127,15 +169,18 @@ namespace MQTTnet.Formatter return value; } - private async Task ReadByteAsync(CancellationToken cancellationToken) + private async Task ReadByteAsync(CancellationToken cancellationToken) { var readCount = await _channel.ReadAsync(_singleByteBuffer, 0, 1, cancellationToken).ConfigureAwait(false); - cancellationToken.ThrowIfCancellationRequested(); + if (cancellationToken.IsCancellationRequested) + { + return null; + } - if (readCount <= 0) + if (readCount == 0) { - ExceptionHelper.ThrowGracefulSocketClose(); + return null; } return _singleByteBuffer[0]; diff --git a/Source/MQTTnet/Formatter/ReadFixedHeaderResult.cs b/Source/MQTTnet/Formatter/ReadFixedHeaderResult.cs new file mode 100644 index 0000000..30ad73e --- /dev/null +++ b/Source/MQTTnet/Formatter/ReadFixedHeaderResult.cs @@ -0,0 +1,9 @@ +namespace MQTTnet.Formatter +{ + public class ReadFixedHeaderResult + { + public bool ConnectionClosed { get; set; } + + public MqttFixedHeader FixedHeader { get; set; } + } +} diff --git a/Source/MQTTnet/Formatter/V3/MqttV310DataConverter.cs b/Source/MQTTnet/Formatter/V3/MqttV310DataConverter.cs index b643675..1fad36b 100644 --- a/Source/MQTTnet/Formatter/V3/MqttV310DataConverter.cs +++ b/Source/MQTTnet/Formatter/V3/MqttV310DataConverter.cs @@ -143,5 +143,35 @@ namespace MQTTnet.Formatter.V3 return result; } + + public MqttSubscribePacket CreateSubscribePacket(MqttClientSubscribeOptions options) + { + if (options == null) throw new ArgumentNullException(nameof(options)); + + if (options.UserProperties?.Any() == true) + { + throw new MqttProtocolViolationException("User properties are not supported in MQTT version 3."); + } + + var subscribePacket = new MqttSubscribePacket(); + subscribePacket.TopicFilters.AddRange(options.TopicFilters); + + return subscribePacket; + } + + public MqttUnsubscribePacket CreateUnsubscribePacket(MqttClientUnsubscribeOptions options) + { + if (options == null) throw new ArgumentNullException(nameof(options)); + + if (options.UserProperties?.Any() == true) + { + throw new MqttProtocolViolationException("User properties are not supported in MQTT version 3."); + } + + var unsubscribePacket = new MqttUnsubscribePacket(); + unsubscribePacket.TopicFilters.AddRange(options.TopicFilters); + + return unsubscribePacket; + } } } diff --git a/Source/MQTTnet/Formatter/V5/MqttV500DataConverter.cs b/Source/MQTTnet/Formatter/V5/MqttV500DataConverter.cs index 5a26a67..b5522e1 100644 --- a/Source/MQTTnet/Formatter/V5/MqttV500DataConverter.cs +++ b/Source/MQTTnet/Formatter/V5/MqttV500DataConverter.cs @@ -128,5 +128,35 @@ namespace MQTTnet.Formatter.V5 return result; } + + public MqttSubscribePacket CreateSubscribePacket(MqttClientSubscribeOptions options) + { + if (options == null) throw new ArgumentNullException(nameof(options)); + + var packet = new MqttSubscribePacket + { + Properties = new MqttSubscribePacketProperties() + }; + + packet.TopicFilters.AddRange(options.TopicFilters); + packet.Properties.UserProperties.AddRange(options.UserProperties); + + return packet; + } + + public MqttUnsubscribePacket CreateUnsubscribePacket(MqttClientUnsubscribeOptions options) + { + if (options == null) throw new ArgumentNullException(nameof(options)); + + var packet = new MqttUnsubscribePacket + { + Properties = new MqttUnsubscribePacketProperties() + }; + + packet.TopicFilters.AddRange(options.TopicFilters); + packet.Properties.UserProperties.AddRange(options.UserProperties); + + return packet; + } } } diff --git a/Source/MQTTnet/IApplicationMessageReceiver.cs b/Source/MQTTnet/IApplicationMessageReceiver.cs index b315067..806bb75 100644 --- a/Source/MQTTnet/IApplicationMessageReceiver.cs +++ b/Source/MQTTnet/IApplicationMessageReceiver.cs @@ -1,9 +1,13 @@ using System; +using MQTTnet.Client.Receiving; namespace MQTTnet { public interface IApplicationMessageReceiver { + IMqttApplicationMessageHandler ReceivedApplicationMessageHandler { get; set; } + + [Obsolete("Use _ReceivedApplicationMessageHandler_ instead.")] event EventHandler ApplicationMessageReceived; } } diff --git a/Source/MQTTnet/Internal/ExceptionHelper.cs b/Source/MQTTnet/Internal/ExceptionHelper.cs deleted file mode 100644 index 5bc8e43..0000000 --- a/Source/MQTTnet/Internal/ExceptionHelper.cs +++ /dev/null @@ -1,20 +0,0 @@ -using MQTTnet.Exceptions; - -namespace MQTTnet.Internal -{ - public static class ExceptionHelper - { - public static void ThrowGracefulSocketClose() - { - throw new MqttCommunicationClosedGracefullyException(); - } - - public static void ThrowIfGracefulSocketClose(int readBytesCount) - { - if (readBytesCount <= 0) - { - throw new MqttCommunicationClosedGracefullyException(); - } - } - } -} diff --git a/Source/MQTTnet/MQTTnet.csproj b/Source/MQTTnet/MQTTnet.csproj index 91e455f..d181925 100644 --- a/Source/MQTTnet/MQTTnet.csproj +++ b/Source/MQTTnet/MQTTnet.csproj @@ -63,20 +63,8 @@ - - - - - - - - - C:\Program Files\dotnet\sdk\NuGetFallbackFolder\microsoft.netcore.app\2.1.0\ref\netcoreapp2.1\System.Net.Requests.dll - - - \ No newline at end of file diff --git a/Source/MQTTnet/MessageStream/MqttMessageStream.cs b/Source/MQTTnet/MessageStream/MqttMessageStream.cs new file mode 100644 index 0000000..ef88147 --- /dev/null +++ b/Source/MQTTnet/MessageStream/MqttMessageStream.cs @@ -0,0 +1,79 @@ +using System; +using System.Collections.Generic; + +using System.Runtime.CompilerServices; +using System.Threading; +using System.Threading.Tasks; +using MQTTnet.Internal; +using MQTTnet.Packets; + +namespace MQTTnet.MessageStream +{ + public class MqttMessageStream + { + private readonly LinkedList _queue = new LinkedList(); + private readonly LinkedList> _waitHandles = new LinkedList>(); + + private readonly SemaphoreSlim _lock = new SemaphoreSlim(1, 1); + + private readonly BlockingQueue _packets = new BlockingQueue(); + + public void Enqueue(MqttBasePacket packet) + { + if (packet == null) throw new ArgumentNullException(nameof(packet)); + + + + lock (_queue) + { + //_queue.AddLast(packet); + + _packets.Enqueue(packet); + } + + //lock (_queue) + //{ + // _queue.AddLast(packet); + + // foreach (var waitHandle in _waitHandles) + // { + // waitHandle.TrySetResult(true); + // } + + // _waitHandles.Clear(); + //} + } + + public Task TakeAsync(CancellationToken cancellationToken) + { + lock (_packets) + { + var packet = _packets.Dequeue(); + return Task.FromResult(packet); + } + + + //while (!cancellationToken.IsCancellationRequested) + //{ + // TaskCompletionSource waitHandle; + // lock (_queue) + // { + // if (_queue.Count > 0) + // { + // var node = _queue.First; + // _queue.RemoveFirst(); + + // return node.Value; + // } + + // waitHandle = new TaskCompletionSource(); + // _waitHandles.Add(waitHandle); + // } + + // await waitHandle.Task; + //} + + //return null; + } + } +} diff --git a/Source/MQTTnet/Client/PacketDispatcher/IMqttPacketAwaiter.cs b/Source/MQTTnet/PacketDispatcher/IMqttPacketAwaiter.cs similarity index 55% rename from Source/MQTTnet/Client/PacketDispatcher/IMqttPacketAwaiter.cs rename to Source/MQTTnet/PacketDispatcher/IMqttPacketAwaiter.cs index 4887b1e..786dc97 100644 --- a/Source/MQTTnet/Client/PacketDispatcher/IMqttPacketAwaiter.cs +++ b/Source/MQTTnet/PacketDispatcher/IMqttPacketAwaiter.cs @@ -1,12 +1,14 @@ using System; using MQTTnet.Packets; -namespace MQTTnet.Client.PacketDispatcher +namespace MQTTnet.PacketDispatcher { - public interface IMqttPacketAwaiter + public interface IMqttPacketAwaiter : IDisposable { void Complete(MqttBasePacket packet); void Fail(Exception exception); + + void Cancel(); } } \ No newline at end of file diff --git a/Source/MQTTnet/PacketDispatcher/MqttPacketAwaiter.cs b/Source/MQTTnet/PacketDispatcher/MqttPacketAwaiter.cs new file mode 100644 index 0000000..c30a6f4 --- /dev/null +++ b/Source/MQTTnet/PacketDispatcher/MqttPacketAwaiter.cs @@ -0,0 +1,59 @@ +using System; +using System.Threading; +using System.Threading.Tasks; +using MQTTnet.Packets; + +namespace MQTTnet.PacketDispatcher +{ + public sealed class MqttPacketAwaiter : IMqttPacketAwaiter where TPacket : MqttBasePacket + { + private readonly TaskCompletionSource _taskCompletionSource = new TaskCompletionSource(); + private readonly ushort? _packetIdentifier; + private readonly MqttPacketDispatcher _owningPacketDispatcher; + + public MqttPacketAwaiter(ushort? packetIdentifier, MqttPacketDispatcher owningPacketDispatcher) + { + _packetIdentifier = packetIdentifier; + _owningPacketDispatcher = owningPacketDispatcher ?? throw new ArgumentNullException(nameof(owningPacketDispatcher)); + } + + public async Task WaitOneAsync(TimeSpan timeout) + { + using (var timeoutToken = new CancellationTokenSource(timeout)) + { + timeoutToken.Token.Register(() => _taskCompletionSource.TrySetCanceled()); + + var packet = await _taskCompletionSource.Task.ConfigureAwait(false); + return (TPacket)packet; + } + } + + public void Complete(MqttBasePacket packet) + { + if (packet == null) throw new ArgumentNullException(nameof(packet)); + + // To prevent deadlocks it is required to call the _TrySetResult_ method + // from a new thread because the awaiting code will not(!) be executed in + // a new thread automatically (due to await). Furthermore _this_ thread will + // do it. But _this_ thread is also reading incoming packets -> deadlock. + Task.Run(() => _taskCompletionSource.TrySetResult(packet)); + } + + public void Fail(Exception exception) + { + if (exception == null) throw new ArgumentNullException(nameof(exception)); + + Task.Run(() => _taskCompletionSource.TrySetException(exception)); + } + + public void Cancel() + { + Task.Run(() => _taskCompletionSource.TrySetCanceled()); + } + + public void Dispose() + { + _owningPacketDispatcher.RemovePacketAwaiter(_packetIdentifier); + } + } +} \ No newline at end of file diff --git a/Source/MQTTnet/Client/PacketDispatcher/MqttPacketDispatcher.cs b/Source/MQTTnet/PacketDispatcher/MqttPacketDispatcher.cs similarity index 53% rename from Source/MQTTnet/Client/PacketDispatcher/MqttPacketDispatcher.cs rename to Source/MQTTnet/PacketDispatcher/MqttPacketDispatcher.cs index 8c6a769..b4a26c2 100644 --- a/Source/MQTTnet/Client/PacketDispatcher/MqttPacketDispatcher.cs +++ b/Source/MQTTnet/PacketDispatcher/MqttPacketDispatcher.cs @@ -1,22 +1,24 @@ using System; using System.Collections.Concurrent; -using System.Threading.Tasks; +using System.Threading; using MQTTnet.Packets; -namespace MQTTnet.Client.PacketDispatcher +namespace MQTTnet.PacketDispatcher { public class MqttPacketDispatcher { - private readonly ConcurrentDictionary, IMqttPacketAwaiter> _awaiters = new ConcurrentDictionary, IMqttPacketAwaiter>(); + private readonly ConcurrentDictionary, IMqttPacketAwaiter> _packetAwaiters = new ConcurrentDictionary, IMqttPacketAwaiter>(); + + private BlockingCollection _inboundPackagesQueue = new BlockingCollection(); public void Dispatch(Exception exception) { - foreach (var awaiter in _awaiters) + foreach (var awaiter in _packetAwaiters) { - Task.Run(() => awaiter.Value.Fail(exception)); // Task.Run fixes a dead lock. Without this the client only receives one message. + awaiter.Value.Fail(exception); } - _awaiters.Clear(); + _packetAwaiters.Clear(); } public void Dispatch(MqttBasePacket packet) @@ -32,31 +34,56 @@ namespace MQTTnet.Client.PacketDispatcher var type = packet.GetType(); var key = new Tuple(identifier, type); - if (_awaiters.TryRemove(key, out var awaiter)) + if (_packetAwaiters.TryRemove(key, out var awaiter)) { - Task.Run(() => awaiter.Complete(packet)); // Task.Run fixes a dead lock. Without this the client only receives one message. + awaiter.Complete(packet); return; } - throw new InvalidOperationException($"Packet of type '{type.Name}' not handled or dispatched."); + lock (_inboundPackagesQueue) + { + _inboundPackagesQueue.Add(packet); + } + } + + public MqttBasePacket Take(CancellationToken cancellationToken) + { + BlockingCollection inboundPackagesQueue; + lock (_inboundPackagesQueue) + { + inboundPackagesQueue = _inboundPackagesQueue; + } + + return inboundPackagesQueue.Take(cancellationToken); } public void Reset() { - _awaiters.Clear(); + foreach (var awaiter in _packetAwaiters) + { + awaiter.Value.Cancel(); + } + + lock (_inboundPackagesQueue) + { + _inboundPackagesQueue?.Dispose(); + _inboundPackagesQueue = new BlockingCollection(); + } + + _packetAwaiters.Clear(); } public MqttPacketAwaiter AddPacketAwaiter(ushort? identifier) where TResponsePacket : MqttBasePacket { - var awaiter = new MqttPacketAwaiter(); - if (!identifier.HasValue) { identifier = 0; } + var awaiter = new MqttPacketAwaiter(identifier, this); + var key = new Tuple(identifier.Value, typeof(TResponsePacket)); - if (!_awaiters.TryAdd(key, awaiter)) + if (!_packetAwaiters.TryAdd(key, awaiter)) { throw new InvalidOperationException($"The packet dispatcher already has an awaiter for packet of type '{key.Item2.Name}' with identifier {key.Item1}."); } @@ -71,8 +98,8 @@ namespace MQTTnet.Client.PacketDispatcher identifier = 0; } - var key = new Tuple(identifier ?? 0, typeof(TResponsePacket)); - _awaiters.TryRemove(key, out _); + var key = new Tuple(identifier.Value, typeof(TResponsePacket)); + _packetAwaiters.TryRemove(key, out _); } } } \ No newline at end of file diff --git a/Source/MQTTnet/Server/MqttClientSession.cs b/Source/MQTTnet/Server/MqttClientSession.cs index 688d094..ad6fdfa 100644 --- a/Source/MQTTnet/Server/MqttClientSession.cs +++ b/Source/MQTTnet/Server/MqttClientSession.cs @@ -7,6 +7,8 @@ using MQTTnet.Adapter; using MQTTnet.Client; using MQTTnet.Diagnostics; using MQTTnet.Exceptions; +using MQTTnet.MessageStream; +using MQTTnet.PacketDispatcher; using MQTTnet.Packets; using MQTTnet.Protocol; @@ -15,6 +17,7 @@ namespace MQTTnet.Server public class MqttClientSession : IMqttClientSession { private readonly MqttPacketIdentifierProvider _packetIdentifierProvider = new MqttPacketIdentifierProvider(); + private readonly MqttPacketDispatcher _packetDispatcher = new MqttPacketDispatcher(); private readonly MqttRetainedMessagesManager _retainedMessagesManager; private readonly MqttServerEventDispatcher _eventDispatcher; @@ -56,7 +59,7 @@ namespace MQTTnet.Server _keepAliveMonitor = new MqttClientKeepAliveMonitor(this, _logger); _subscriptionsManager = new MqttClientSubscriptionsManager(clientId, _options, eventDispatcher); - _pendingMessagesQueue = new MqttClientSessionPendingMessagesQueue(_options, this, _logger); + _pendingMessagesQueue = new MqttClientSessionPendingMessagesQueue(_options, this, _packetDispatcher, _logger); } public string ClientId { get; } @@ -144,13 +147,14 @@ namespace MQTTnet.Server } var publishPacket = _channelAdapter.PacketFormatterAdapter.DataConverter.CreatePublishPacket(applicationMessage); + publishPacket.QualityOfServiceLevel = checkSubscriptionsResult.QualityOfServiceLevel; // Set the retain flag to true according to [MQTT-3.3.1-8] and [MQTT-3.3.1-9]. publishPacket.Retain = isRetainedApplicationMessage; if (publishPacket.QualityOfServiceLevel > 0) { - publishPacket.PacketIdentifier = _packetIdentifierProvider.GetNewPacketIdentifier(); + publishPacket.PacketIdentifier = _packetIdentifierProvider.GetNextPacketIdentifier(); } if (_options.ClientMessageQueueInterceptor != null) @@ -178,6 +182,8 @@ namespace MQTTnet.Server _pendingMessagesQueue.Enqueue(publishPacket); } + + private async Task RunInternalAsync(MqttConnectPacket connectPacket, IMqttChannelAdapter channelAdapter) { if (channelAdapter == null) throw new ArgumentNullException(nameof(channelAdapter)); @@ -212,13 +218,41 @@ namespace MQTTnet.Server _isCleanSession = false; + Task.Run(async () => + { + while (!cancellationTokenSource.IsCancellationRequested) + { + var packet = _packetDispatcher.Take(cancellationTokenSource.Token); + await ProcessReceivedPacketAsync(packet, cancellationTokenSource.Token).ConfigureAwait(false); + } + }, cancellationTokenSource.Token); + + Task.Run(async () => + { + while (!cancellationTokenSource.IsCancellationRequested) + { + try + { + var packet = await _outboundMessageStream.TakeAsync(cancellationTokenSource.Token); + await channelAdapter.SendPacketAsync(packet, cancellationTokenSource.Token); + } + catch (Exception e) + { + _logger.Error(e, "sdfsdf"); + await StopAsync(MqttClientDisconnectType.NotClean); + + } + + } + },cancellationTokenSource.Token); + while (!cancellationTokenSource.IsCancellationRequested) { var packet = await channelAdapter.ReceivePacketAsync(TimeSpan.Zero, cancellationTokenSource.Token).ConfigureAwait(false); if (packet != null) { _keepAliveMonitor.PacketReceived(packet); - await ProcessReceivedPacketAsync(channelAdapter, packet, cancellationTokenSource.Token).ConfigureAwait(false); + _packetDispatcher.Dispatch(packet); } } } @@ -229,17 +263,7 @@ namespace MQTTnet.Server { if (exception is MqttCommunicationException) { - if (exception is MqttCommunicationClosedGracefullyException) - { - _logger.Verbose("Client '{0}': Connection closed gracefully.", ClientId); - - StopInternal(MqttClientDisconnectType.Clean); - return; - } - else - { - _logger.Warning(exception, "Client '{0}': Communication exception while receiving client packets.", ClientId); - } + _logger.Warning(exception, "Client '{0}': Communication exception while receiving client packets.", ClientId); } else { @@ -256,18 +280,18 @@ namespace MQTTnet.Server } _willMessage = null; - + _channelAdapter.ReadingPacketStarted -= OnAdapterReadingPacketStarted; _channelAdapter.ReadingPacketCompleted -= OnAdapterReadingPacketCompleted; _channelAdapter = null; - + _logger.Info("Client '{0}': Session stopped.", ClientId); _eventDispatcher.OnClientDisconnected(ClientId, _wasCleanDisconnect); _workerTask = null; } } - + private void StopInternal(MqttClientDisconnectType type) { var cts = _cancellationTokenSource; @@ -278,55 +302,33 @@ namespace MQTTnet.Server _wasCleanDisconnect = type == MqttClientDisconnectType.Clean; _cancellationTokenSource?.Cancel(false); + _packetDispatcher.Reset(); } - private Task ProcessReceivedPacketAsync(IMqttChannelAdapter channelAdapter, MqttBasePacket packet, CancellationToken cancellationToken) + private readonly MqttMessageStream _outboundMessageStream = new MqttMessageStream(); + + private Task ProcessReceivedPacketAsync(MqttBasePacket packet, CancellationToken cancellationToken) { if (packet is MqttPublishPacket publishPacket) { - return HandleIncomingPublishPacketAsync(channelAdapter, publishPacket, cancellationToken); + return HandleIncomingPublishPacketAsync(publishPacket, cancellationToken); } if (packet is MqttPingReqPacket) { - return channelAdapter.SendPacketAsync(new MqttPingRespPacket(), cancellationToken); - } - - if (packet is MqttPubRelPacket pubRelPacket) - { - var responsePacket = new MqttPubCompPacket - { - PacketIdentifier = pubRelPacket.PacketIdentifier, - ReasonCode = MqttPubCompReasonCode.Success - }; - - return channelAdapter.SendPacketAsync(responsePacket, cancellationToken); - } - - if (packet is MqttPubRecPacket pubRecPacket) - { - var responsePacket = new MqttPubRelPacket - { - PacketIdentifier = pubRecPacket.PacketIdentifier, - ReasonCode = MqttPubRelReasonCode.Success - }; - - return channelAdapter.SendPacketAsync(responsePacket, cancellationToken); - } - - if (packet is MqttPubAckPacket || packet is MqttPubCompPacket) - { + //return channelAdapter.SendPacketAsync(new MqttPingRespPacket(), cancellationToken); + _outboundMessageStream.Enqueue(new MqttPingRespPacket()); return Task.FromResult(0); } if (packet is MqttSubscribePacket subscribePacket) { - return HandleIncomingSubscribePacketAsync(channelAdapter, subscribePacket, cancellationToken); + return HandleIncomingSubscribePacketAsync(subscribePacket, cancellationToken); } if (packet is MqttUnsubscribePacket unsubscribePacket) { - return HandleIncomingUnsubscribePacketAsync(channelAdapter, unsubscribePacket, cancellationToken); + return HandleIncomingUnsubscribePacketAsync(unsubscribePacket, cancellationToken); } if (packet is MqttDisconnectPacket) @@ -335,6 +337,18 @@ namespace MQTTnet.Server return Task.FromResult(0); } + //if (packet is MqttAuthPacket || + // packet is MqttSubAckPacket || + // packet is MqttUnsubAckPacket || + // packet is MqttPubAckPacket || + // packet is MqttPubCompPacket || + // packet is MqttPubRecPacket || + // packet is MqttPubRelPacket) + //{ + // _packetDispatcher.TryDispatch(packet); + // return Task.FromResult(0); + //} + _logger.Warning(null, "Client '{0}': Received invalid packet ({1}). Closing connection.", ClientId, packet); StopInternal(MqttClientDisconnectType.NotClean); @@ -350,10 +364,14 @@ namespace MQTTnet.Server } } - private async Task HandleIncomingSubscribePacketAsync(IMqttChannelAdapter adapter, MqttSubscribePacket subscribePacket, CancellationToken cancellationToken) + private async Task HandleIncomingSubscribePacketAsync(MqttSubscribePacket subscribePacket, CancellationToken cancellationToken) { var subscribeResult = await _subscriptionsManager.SubscribeAsync(subscribePacket).ConfigureAwait(false); - await adapter.SendPacketAsync(subscribeResult.ResponsePacket, cancellationToken).ConfigureAwait(false); + + _outboundMessageStream.Enqueue(subscribeResult.ResponsePacket); + //await adapter.SendPacketAsync(subscribeResult.ResponsePacket, cancellationToken).ConfigureAwait(false); + + // TODO: Add "WaitForDelivery". if (subscribeResult.CloseConnection) { @@ -364,13 +382,17 @@ namespace MQTTnet.Server await EnqueueSubscribedRetainedMessagesAsync(subscribePacket.TopicFilters).ConfigureAwait(false); } - private Task HandleIncomingUnsubscribePacketAsync(IMqttChannelAdapter adapter, MqttUnsubscribePacket unsubscribePacket, CancellationToken cancellationToken) + private Task HandleIncomingUnsubscribePacketAsync(MqttUnsubscribePacket unsubscribePacket, CancellationToken cancellationToken) { var unsubscribeResult = _subscriptionsManager.Unsubscribe(unsubscribePacket); - return adapter.SendPacketAsync(unsubscribeResult, cancellationToken); + + _outboundMessageStream.Enqueue(unsubscribeResult); + + //return adapter.SendPacketAsync(unsubscribeResult, cancellationToken); + return Task.FromResult(0); } - private Task HandleIncomingPublishPacketAsync(IMqttChannelAdapter adapter, MqttPublishPacket publishPacket, CancellationToken cancellationToken) + private Task HandleIncomingPublishPacketAsync(MqttPublishPacket publishPacket, CancellationToken cancellationToken) { Interlocked.Increment(ref _receivedMessagesCount); @@ -382,11 +404,11 @@ namespace MQTTnet.Server } case MqttQualityOfServiceLevel.AtLeastOnce: { - return HandleIncomingPublishPacketWithQoS1Async(adapter, publishPacket, cancellationToken); + return HandleIncomingPublishPacketWithQoS1Async(publishPacket, cancellationToken); } case MqttQualityOfServiceLevel.ExactlyOnce: { - return HandleIncomingPublishPacketWithQoS2Async(adapter, publishPacket, cancellationToken); + return HandleIncomingPublishPacketWithQoS2Async(publishPacket, cancellationToken); } default: { @@ -405,7 +427,6 @@ namespace MQTTnet.Server } private Task HandleIncomingPublishPacketWithQoS1Async( - IMqttChannelAdapter adapter, MqttPublishPacket publishPacket, CancellationToken cancellationToken) { @@ -419,24 +440,32 @@ namespace MQTTnet.Server ReasonCode = MqttPubAckReasonCode.Success }; - return adapter.SendPacketAsync(response, cancellationToken); + _outboundMessageStream.Enqueue(response); + + //return adapter.SendPacketAsync(response, cancellationToken); + return Task.FromResult(0); } - private Task HandleIncomingPublishPacketWithQoS2Async( - IMqttChannelAdapter adapter, + private async Task HandleIncomingPublishPacketWithQoS2Async( MqttPublishPacket publishPacket, CancellationToken cancellationToken) { // QoS 2 is implement as method "B" (4.3.3 QoS 2: Exactly once delivery) _sessionsManager.EnqueueApplicationMessage(this, _channelAdapter.PacketFormatterAdapter.DataConverter.CreateApplicationMessage(publishPacket)); - var response = new MqttPubRecPacket + using (var pubRelPacketAwaiter = _packetDispatcher.AddPacketAwaiter(publishPacket.PacketIdentifier)) { - PacketIdentifier = publishPacket.PacketIdentifier, - ReasonCode = MqttPubRecReasonCode.Success - }; + var pubRecPacket = new MqttPubRecPacket + { + PacketIdentifier = publishPacket.PacketIdentifier, + ReasonCode = MqttPubRecReasonCode.Success + }; + + //await adapter.SendPacketAsync(pubRecPacket, cancellationToken).ConfigureAwait(false); + _outboundMessageStream.Enqueue(pubRecPacket); - return adapter.SendPacketAsync(response, cancellationToken); + await pubRelPacketAwaiter.WaitOneAsync(_options.DefaultCommunicationTimeout).ConfigureAwait(false); + } } private void OnAdapterReadingPacketCompleted(object sender, EventArgs e) diff --git a/Source/MQTTnet/Server/MqttClientSessionPendingMessagesQueue.cs b/Source/MQTTnet/Server/MqttClientSessionPendingMessagesQueue.cs index 60d5869..1fd19e8 100644 --- a/Source/MQTTnet/Server/MqttClientSessionPendingMessagesQueue.cs +++ b/Source/MQTTnet/Server/MqttClientSessionPendingMessagesQueue.cs @@ -6,6 +6,7 @@ using MQTTnet.Adapter; using MQTTnet.Diagnostics; using MQTTnet.Exceptions; using MQTTnet.Internal; +using MQTTnet.PacketDispatcher; using MQTTnet.Packets; using MQTTnet.Protocol; @@ -13,20 +14,26 @@ namespace MQTTnet.Server { public class MqttClientSessionPendingMessagesQueue : IDisposable { - private readonly Queue _queue = new Queue(); + private readonly Queue _queue = new Queue(); private readonly AsyncAutoResetEvent _queueLock = new AsyncAutoResetEvent(); private readonly IMqttServerOptions _options; private readonly MqttClientSession _clientSession; + private readonly MqttPacketDispatcher _packetDispatcher; private readonly IMqttNetChildLogger _logger; private long _sentPacketsCount; - public MqttClientSessionPendingMessagesQueue(IMqttServerOptions options, MqttClientSession clientSession, IMqttNetChildLogger logger) + public MqttClientSessionPendingMessagesQueue( + IMqttServerOptions options, + MqttClientSession clientSession, + MqttPacketDispatcher packetDispatcher, + IMqttNetChildLogger logger) { if (logger == null) throw new ArgumentNullException(nameof(logger)); _options = options ?? throw new ArgumentNullException(nameof(options)); _clientSession = clientSession ?? throw new ArgumentNullException(nameof(clientSession)); + _packetDispatcher = packetDispatcher ?? throw new ArgumentNullException(nameof(packetDispatcher)); _logger = logger.CreateChildLogger(nameof(MqttClientSessionPendingMessagesQueue)); } @@ -115,7 +122,7 @@ namespace MQTTnet.Server private async Task TrySendNextQueuedPacketAsync(IMqttChannelAdapter adapter, CancellationToken cancellationToken) { - MqttBasePacket packet = null; + MqttPublishPacket packet = null; try { if (cancellationToken.IsCancellationRequested) @@ -137,7 +144,34 @@ namespace MQTTnet.Server return; } - await adapter.SendPacketAsync(packet, cancellationToken).ConfigureAwait(false); + if (packet.QualityOfServiceLevel == MqttQualityOfServiceLevel.AtMostOnce) + { + await adapter.SendPacketAsync(packet, cancellationToken).ConfigureAwait(false); + } + else if (packet.QualityOfServiceLevel == MqttQualityOfServiceLevel.AtLeastOnce) + { + var awaiter = _packetDispatcher.AddPacketAwaiter(packet.PacketIdentifier); + await adapter.SendPacketAsync(packet, cancellationToken).ConfigureAwait(false); + await awaiter.WaitOneAsync(_options.DefaultCommunicationTimeout).ConfigureAwait(false); + } + else if (packet.QualityOfServiceLevel == MqttQualityOfServiceLevel.ExactlyOnce) + { + var awaiter1 = _packetDispatcher.AddPacketAwaiter(packet.PacketIdentifier); + var awaiter2 = _packetDispatcher.AddPacketAwaiter(packet.PacketIdentifier); + try + { + await adapter.SendPacketAsync(packet, cancellationToken).ConfigureAwait(false); + await awaiter1.WaitOneAsync(_options.DefaultCommunicationTimeout).ConfigureAwait(false); + + await adapter.SendPacketAsync(new MqttPubRelPacket { PacketIdentifier = packet.PacketIdentifier }, cancellationToken).ConfigureAwait(false); + await awaiter2.WaitOneAsync(_options.DefaultCommunicationTimeout).ConfigureAwait(false); + } + finally + { + _packetDispatcher.RemovePacketAwaiter(packet.PacketIdentifier); + _packetDispatcher.RemovePacketAwaiter(packet.PacketIdentifier); + } + } _logger.Verbose("Enqueued packet sent (ClientId: {0}).", _clientSession.ClientId); @@ -153,7 +187,7 @@ namespace MQTTnet.Server { _logger.Warning(exception, "Sending publish packet failed: Communication exception (ClientId: {0}).", _clientSession.ClientId); } - else if (exception is OperationCanceledException) + else if (exception is OperationCanceledException && cancellationToken.IsCancellationRequested) { } else @@ -161,14 +195,11 @@ namespace MQTTnet.Server _logger.Error(exception, "Sending publish packet failed (ClientId: {0}).", _clientSession.ClientId); } - if (packet is MqttPublishPacket publishPacket) + if (packet?.QualityOfServiceLevel > MqttQualityOfServiceLevel.AtMostOnce) { - if (publishPacket.QualityOfServiceLevel > MqttQualityOfServiceLevel.AtMostOnce) - { - publishPacket.Dup = true; + packet.Dup = true; - Enqueue(publishPacket); - } + Enqueue(packet); } if (!cancellationToken.IsCancellationRequested) diff --git a/Source/MQTTnet/Server/MqttServer.cs b/Source/MQTTnet/Server/MqttServer.cs index ffba8fb..d2145bd 100644 --- a/Source/MQTTnet/Server/MqttServer.cs +++ b/Source/MQTTnet/Server/MqttServer.cs @@ -5,6 +5,7 @@ using System.Threading; using System.Threading.Tasks; using MQTTnet.Adapter; using MQTTnet.Client.Publishing; +using MQTTnet.Client.Receiving; using MQTTnet.Diagnostics; namespace MQTTnet.Server @@ -31,7 +32,18 @@ namespace MQTTnet.Server _eventDispatcher.ClientDisconnected += (s, e) => ClientDisconnected?.Invoke(s, e); _eventDispatcher.ClientSubscribedTopic += (s, e) => ClientSubscribedTopic?.Invoke(s, e); _eventDispatcher.ClientUnsubscribedTopic += (s, e) => ClientUnsubscribedTopic?.Invoke(s, e); - _eventDispatcher.ApplicationMessageReceived += (s, e) => ApplicationMessageReceived?.Invoke(s, e); + _eventDispatcher.ApplicationMessageReceived += async (s, e) => + { + // TODO: Migrate EventDispatcher to proper handlers and no events anymore. + ApplicationMessageReceived?.Invoke(s, e); + + var handler = ReceivedApplicationMessageHandler; + if (handler != null) + { + await handler.HandleApplicationMessageAsync( + new MqttApplicationMessageHandlerContext(e.ClientId, e.ApplicationMessage)).ConfigureAwait(false); + } + }; } public event EventHandler Started; @@ -42,6 +54,7 @@ namespace MQTTnet.Server public event EventHandler ClientSubscribedTopic; public event EventHandler ClientUnsubscribedTopic; + public IMqttApplicationMessageHandler ReceivedApplicationMessageHandler { get; set; } public event EventHandler ApplicationMessageReceived; public IMqttServerOptions Options { get; private set; } diff --git a/Tests/MQTTnet.Benchmarks/SerializerBenchmark.cs b/Tests/MQTTnet.Benchmarks/SerializerBenchmark.cs index b1797ef..346626e 100644 --- a/Tests/MQTTnet.Benchmarks/SerializerBenchmark.cs +++ b/Tests/MQTTnet.Benchmarks/SerializerBenchmark.cs @@ -52,7 +52,7 @@ namespace MQTTnet.Benchmarks { channel.Reset(); - var header = reader.ReadFixedHeaderAsync(fixedHeader, CancellationToken.None).GetAwaiter().GetResult(); + var header = reader.ReadFixedHeaderAsync(fixedHeader, CancellationToken.None).GetAwaiter().GetResult().FixedHeader; var receivedPacket = new ReceivedMqttPacket( header.Flags, diff --git a/Tests/MQTTnet.Core.Tests/ManagedMqttClientTests.cs b/Tests/MQTTnet.Core.Tests/ManagedMqttClientTests.cs index 559a87c..1f19521 100644 --- a/Tests/MQTTnet.Core.Tests/ManagedMqttClientTests.cs +++ b/Tests/MQTTnet.Core.Tests/ManagedMqttClientTests.cs @@ -13,26 +13,32 @@ namespace MQTTnet.Tests { var factory = new MqttFactory(); var managedClient = factory.CreateManagedMqttClient(); + try + { + var clientOptions = new ManagedMqttClientOptionsBuilder() + .WithMaxPendingMessages(5) + .WithPendingMessagesOverflowStrategy(MqttPendingMessagesOverflowStrategy.DropNewMessage); - var clientOptions = new ManagedMqttClientOptionsBuilder() - .WithMaxPendingMessages(5) - .WithPendingMessagesOverflowStrategy(MqttPendingMessagesOverflowStrategy.DropNewMessage); + clientOptions.WithClientOptions(o => o.WithTcpServer("localhost")); - clientOptions.WithClientOptions(o => o.WithTcpServer("localhost")); + await managedClient.StartAsync(clientOptions.Build()); - await managedClient.StartAsync(clientOptions.Build()); + await managedClient.PublishAsync(new MqttApplicationMessage { Topic = "1" }); + await managedClient.PublishAsync(new MqttApplicationMessage { Topic = "2" }); + await managedClient.PublishAsync(new MqttApplicationMessage { Topic = "3" }); + await managedClient.PublishAsync(new MqttApplicationMessage { Topic = "4" }); + await managedClient.PublishAsync(new MqttApplicationMessage { Topic = "5" }); - await managedClient.PublishAsync(new MqttApplicationMessage { Topic = "1" }); - await managedClient.PublishAsync(new MqttApplicationMessage { Topic = "2" }); - await managedClient.PublishAsync(new MqttApplicationMessage { Topic = "3" }); - await managedClient.PublishAsync(new MqttApplicationMessage { Topic = "4" }); - await managedClient.PublishAsync(new MqttApplicationMessage { Topic = "5" }); + await managedClient.PublishAsync(new MqttApplicationMessage { Topic = "6" }); + await managedClient.PublishAsync(new MqttApplicationMessage { Topic = "7" }); + await managedClient.PublishAsync(new MqttApplicationMessage { Topic = "8" }); - await managedClient.PublishAsync(new MqttApplicationMessage { Topic = "6" }); - await managedClient.PublishAsync(new MqttApplicationMessage { Topic = "7" }); - await managedClient.PublishAsync(new MqttApplicationMessage { Topic = "8" }); - - Assert.AreEqual(5, managedClient.PendingApplicationMessagesCount); + Assert.AreEqual(5, managedClient.PendingApplicationMessagesCount); + } + finally + { + await managedClient.StopAsync(); + } } } } diff --git a/Tests/MQTTnet.Core.Tests/MqttClientTests.cs b/Tests/MQTTnet.Core.Tests/MqttClientTests.cs index 6c70caf..9725fee 100644 --- a/Tests/MQTTnet.Core.Tests/MqttClientTests.cs +++ b/Tests/MQTTnet.Core.Tests/MqttClientTests.cs @@ -2,11 +2,14 @@ using System; using System.Collections.Generic; using System.Linq; using System.Net.Sockets; +using System.Threading; using System.Threading.Tasks; using Microsoft.VisualStudio.TestTools.UnitTesting; using MQTTnet.Client; using MQTTnet.Client.Options; +using MQTTnet.Client.Receiving; using MQTTnet.Exceptions; +using MQTTnet.Protocol; using MQTTnet.Server; namespace MQTTnet.Tests @@ -39,11 +42,100 @@ namespace MQTTnet.Tests Assert.IsInstanceOfType(ex.InnerException, typeof(SocketException)); } + [TestMethod] + public async Task Client_Preserve_Message_Order() + { + // The messages are sent in reverse or to ensure that the delay in the handler + // needs longer for the first messages and later messages may be processed earlier (if there + // is an issue). + const int MessagesCount = 50; + + using (var testSetup = new TestSetup()) + { + await testSetup.StartServerAsync(); + + var client1 = await testSetup.ConnectClientAsync(); + await client1.SubscribeAsync("x"); + + var receivedValues = new List(); + + async Task Handler1(MqttApplicationMessageHandlerContext context) + { + var value = int.Parse(context.ApplicationMessage.ConvertPayloadToString()); + await Task.Delay(value); + + lock (receivedValues) + { + receivedValues.Add(value); + } + } + + client1.UseReceivedApplicationMessageHandler(Handler1); + + var client2 = await testSetup.ConnectClientAsync(); + for (var i = MessagesCount; i > 0; i--) + { + await client2.PublishAsync("x", i.ToString()); + } + + await Task.Delay(5000); + + for (var i = MessagesCount; i > 0; i--) + { + Assert.AreEqual(i, receivedValues[MessagesCount - i]); + } + } + } + + [TestMethod] + public async Task Client_Send_Reply_For_Any_Received_Message() + { + using (var testSetup = new TestSetup()) + { + await testSetup.StartServerAsync(); + + var client1 = await testSetup.ConnectClientAsync(); + await client1.SubscribeAsync("request/+"); + + async Task Handler1(MqttApplicationMessageHandlerContext context) + { + await client1.PublishAsync($"reply/{context.ApplicationMessage.Topic}"); + } + + client1.UseReceivedApplicationMessageHandler(Handler1); + + var client2 = await testSetup.ConnectClientAsync(); + await client2.SubscribeAsync("reply/#"); + + var replies = new List(); + + void Handler2(MqttApplicationMessageHandlerContext context) + { + lock (replies) + { + replies.Add(context.ApplicationMessage.Topic); + } + } + + client2.UseReceivedApplicationMessageHandler((Action) Handler2); + + await Task.Delay(500); + + await client2.PublishAsync("request/a"); + await client2.PublishAsync("request/b"); + await client2.PublishAsync("request/c"); + + await Task.Delay(500); + + Assert.AreEqual("reply/request/a,reply/request/b,reply/request/c", string.Join("," , replies)); + } + } + [TestMethod] public async Task Client_Publish() { var server = new MqttFactory().CreateMqttServer(); - + try { var receivedMessages = new List(); @@ -125,39 +217,85 @@ namespace MQTTnet.Tests } } -//#if DEBUG -// [TestMethod] -// public async Task Client_Cleanup_On_Authentification_Fails() -// { -// var channel = new TestMqttCommunicationAdapter(); -// var channel2 = new TestMqttCommunicationAdapter(); -// channel.Partner = channel2; -// channel2.Partner = channel; - -// Task.Run(async () => { -// var connect = await channel2.ReceivePacketAsync(TimeSpan.Zero, CancellationToken.None); -// await channel2.SendPacketAsync(new MqttConnAckPacket -// { -// ConnectReturnCode = Protocol.MqttConnectReturnCode.ConnectionRefusedNotAuthorized -// }, CancellationToken.None); -// }); - -// var fake = new TestMqttCommunicationAdapterFactory(channel); - -// var client = new MqttClient(fake, new MqttNetLogger()); - -// try -// { -// await client.ConnectAsync(new MqttClientOptionsBuilder().WithTcpServer("any-server").Build()); -// } -// catch (Exception ex) -// { -// Assert.IsInstanceOfType(ex, typeof(MqttConnectingFailedException)); -// } - -// Assert.IsTrue(client._packetReceiverTask == null || client._packetReceiverTask.IsCompleted, "receive loop not completed"); -// Assert.IsTrue(client._keepAliveMessageSenderTask == null || client._keepAliveMessageSenderTask.IsCompleted, "keepalive loop not completed"); -// } -//#endif + [TestMethod] + public async Task Client_Exception_In_Application_Message_Handler() + { + using (var testSetup = new TestSetup()) + { + testSetup.IgnoreClientLogErrors = true; + testSetup.IgnoreServerLogErrors = true; + + await testSetup.StartServerAsync( + new MqttServerOptionsBuilder() + .WithPersistentSessions() + .WithDefaultCommunicationTimeout(TimeSpan.FromMilliseconds(50))); + + var client1 = await testSetup.ConnectClientAsync(new MqttClientOptionsBuilder() + .WithCleanSession(false)); + + await client1.SubscribeAsync("x", MqttQualityOfServiceLevel.AtLeastOnce); + + var retries = 0; + + async Task Handler1(MqttApplicationMessageHandlerContext context) + { + retries++; + + await Task.Delay(50); + throw new Exception("Broken!"); + } + + client1.UseReceivedApplicationMessageHandler(Handler1); + + var client2 = await testSetup.ConnectClientAsync(); + await client2.PublishAsync("x"); + + await Task.Delay(1000); + + // The server should disconnect clients which are not responding. + Assert.IsFalse(client1.IsConnected); + + await client1.ReconnectAsync().ConfigureAwait(false); + + await Task.Delay(1000); + + Assert.AreEqual(2, retries); + } + } + + //#if DEBUG + // [TestMethod] + // public async Task Client_Cleanup_On_Authentification_Fails() + // { + // var channel = new TestMqttCommunicationAdapter(); + // var channel2 = new TestMqttCommunicationAdapter(); + // channel.Partner = channel2; + // channel2.Partner = channel; + + // Task.Run(async () => { + // var connect = await channel2.ReceivePacketAsync(TimeSpan.Zero, CancellationToken.None); + // await channel2.SendPacketAsync(new MqttConnAckPacket + // { + // ConnectReturnCode = Protocol.MqttConnectReturnCode.ConnectionRefusedNotAuthorized + // }, CancellationToken.None); + // }); + + // var fake = new TestMqttCommunicationAdapterFactory(channel); + + // var client = new MqttClient(fake, new MqttNetLogger()); + + // try + // { + // await client.ConnectAsync(new MqttClientOptionsBuilder().WithTcpServer("any-server").Build()); + // } + // catch (Exception ex) + // { + // Assert.IsInstanceOfType(ex, typeof(MqttConnectingFailedException)); + // } + + // Assert.IsTrue(client._packetReceiverTask == null || client._packetReceiverTask.IsCompleted, "receive loop not completed"); + // Assert.IsTrue(client._keepAliveMessageSenderTask == null || client._keepAliveMessageSenderTask.IsCompleted, "keepalive loop not completed"); + // } + //#endif } } diff --git a/Tests/MQTTnet.Core.Tests/MqttPacketIdentifierProviderTests.cs b/Tests/MQTTnet.Core.Tests/MqttPacketIdentifierProviderTests.cs index e5d0cd7..e2f733a 100644 --- a/Tests/MQTTnet.Core.Tests/MqttPacketIdentifierProviderTests.cs +++ b/Tests/MQTTnet.Core.Tests/MqttPacketIdentifierProviderTests.cs @@ -10,10 +10,10 @@ namespace MQTTnet.Tests public void Reset() { var p = new MqttPacketIdentifierProvider(); - Assert.AreEqual(1, p.GetNewPacketIdentifier()); - Assert.AreEqual(2, p.GetNewPacketIdentifier()); + Assert.AreEqual(1, p.GetNextPacketIdentifier()); + Assert.AreEqual(2, p.GetNextPacketIdentifier()); p.Reset(); - Assert.AreEqual(1, p.GetNewPacketIdentifier()); + Assert.AreEqual(1, p.GetNextPacketIdentifier()); } [TestMethod] @@ -23,10 +23,10 @@ namespace MQTTnet.Tests for (ushort i = 0; i < ushort.MaxValue; i++) { - Assert.AreEqual(i + 1, p.GetNewPacketIdentifier()); + Assert.AreEqual(i + 1, p.GetNextPacketIdentifier()); } - Assert.AreEqual(1, p.GetNewPacketIdentifier()); + Assert.AreEqual(1, p.GetNextPacketIdentifier()); } } } diff --git a/Tests/MQTTnet.Core.Tests/MqttPacketReaderTests.cs b/Tests/MQTTnet.Core.Tests/MqttPacketReaderTests.cs index 215dd0a..4ae204b 100644 --- a/Tests/MQTTnet.Core.Tests/MqttPacketReaderTests.cs +++ b/Tests/MQTTnet.Core.Tests/MqttPacketReaderTests.cs @@ -1,7 +1,7 @@ using System.IO; using System.Threading; +using System.Threading.Tasks; using Microsoft.VisualStudio.TestTools.UnitTesting; -using MQTTnet.Exceptions; using MQTTnet.Formatter; using MQTTnet.Internal; @@ -11,12 +11,13 @@ namespace MQTTnet.Tests public class MqttPacketReaderTests { [TestMethod] - [ExpectedException(typeof(MqttCommunicationClosedGracefullyException))] - public void MqttPacketReader_EmptyStream() + public async Task MqttPacketReader_EmptyStream() { var fixedHeader = new byte[2]; var reader = new MqttPacketReader(new TestMqttChannel(new MemoryStream())); - reader.ReadFixedHeaderAsync(fixedHeader, CancellationToken.None).GetAwaiter().GetResult(); + var readResult = await reader.ReadFixedHeaderAsync(fixedHeader, CancellationToken.None); + + Assert.IsTrue(readResult.ConnectionClosed); } } } diff --git a/Tests/MQTTnet.Core.Tests/MqttPacketSerializerTests.cs b/Tests/MQTTnet.Core.Tests/MqttPacketSerializerTests.cs index 2b61274..47ef0fd 100644 --- a/Tests/MQTTnet.Core.Tests/MqttPacketSerializerTests.cs +++ b/Tests/MQTTnet.Core.Tests/MqttPacketSerializerTests.cs @@ -182,7 +182,7 @@ namespace MQTTnet.Tests var header = new MqttPacketReader(testChannel).ReadFixedHeaderAsync( new byte[2], - CancellationToken.None).GetAwaiter().GetResult(); + CancellationToken.None).GetAwaiter().GetResult().FixedHeader; var eof = buffer.Offset + buffer.Count; @@ -550,7 +550,7 @@ namespace MQTTnet.Tests { var channel = new TestMqttChannel(headerStream); var fixedHeader = new byte[2]; - var header = new MqttPacketReader(channel).ReadFixedHeaderAsync(fixedHeader, CancellationToken.None).GetAwaiter().GetResult(); + var header = new MqttPacketReader(channel).ReadFixedHeaderAsync(fixedHeader, CancellationToken.None).GetAwaiter().GetResult().FixedHeader; using (var bodyStream = new MemoryStream(Join(buffer1), (int)headerStream.Position, header.RemainingLength)) { @@ -586,7 +586,7 @@ namespace MQTTnet.Tests var channel = new TestMqttChannel(headerStream); var fixedHeader = new byte[2]; - var header = new MqttPacketReader(channel).ReadFixedHeaderAsync(fixedHeader, CancellationToken.None).GetAwaiter().GetResult(); + var header = new MqttPacketReader(channel).ReadFixedHeaderAsync(fixedHeader, CancellationToken.None).GetAwaiter().GetResult().FixedHeader; using (var bodyStream = new MemoryStream(Join(buffer1), (int)headerStream.Position, (int)header.RemainingLength)) { diff --git a/Tests/MQTTnet.Core.Tests/MqttServerTests.cs b/Tests/MQTTnet.Core.Tests/MqttServerTests.cs index ed6faa9..fc66c2d 100644 --- a/Tests/MQTTnet.Core.Tests/MqttServerTests.cs +++ b/Tests/MQTTnet.Core.Tests/MqttServerTests.cs @@ -1,4 +1,5 @@ using System; +using System.Collections.Concurrent; using System.Collections.Generic; using System.Linq; using System.Net.Sockets; @@ -19,36 +20,36 @@ namespace MQTTnet.Tests public class MqttServerTests { [TestMethod] - public void MqttServer_PublishSimple_AtMostOnce() + public async Task MqttServer_PublishSimple_AtMostOnce() { - TestPublishAsync( + await TestPublishAsync( "A/B/C", MqttQualityOfServiceLevel.AtMostOnce, "A/B/C", MqttQualityOfServiceLevel.AtMostOnce, - 1).Wait(); + 1); } [TestMethod] - public void MqttServer_PublishSimple_AtLeastOnce() + public async Task MqttServer_PublishSimple_AtLeastOnce() { - TestPublishAsync( + await TestPublishAsync( "A/B/C", MqttQualityOfServiceLevel.AtLeastOnce, "A/B/C", MqttQualityOfServiceLevel.AtLeastOnce, - 1).Wait(); + 1); } [TestMethod] - public void MqttServer_PublishSimple_ExactlyOnce() + public async Task MqttServer_PublishSimple_ExactlyOnce() { - TestPublishAsync( + await TestPublishAsync( "A/B/C", MqttQualityOfServiceLevel.ExactlyOnce, "A/B/C", MqttQualityOfServiceLevel.ExactlyOnce, - 1).Wait(); + 1); } [TestMethod] @@ -143,32 +144,27 @@ namespace MQTTnet.Tests } [TestMethod] - public async Task MqttServer_Publish() + public async Task MqttServer_Publish_From_Server() { - var serverAdapter = new TestMqttServerAdapter(); - var s = new MqttFactory().CreateMqttServer(new[] { serverAdapter }, new MqttNetLogger()); - var receivedMessagesCount = 0; - - try + using (var testSetup = new TestSetup()) { - await s.StartAsync(new MqttServerOptions()); + var server = await testSetup.StartServerAsync(); - var c1 = await serverAdapter.ConnectTestClient("c1"); + var receivedMessagesCount = 0; - c1.ApplicationMessageReceived += (_, __) => receivedMessagesCount++; + var client = await testSetup.ConnectClientAsync(); + client.UseReceivedApplicationMessageHandler(c => Interlocked.Increment(ref receivedMessagesCount)); var message = new MqttApplicationMessageBuilder().WithTopic("a").WithAtLeastOnceQoS().Build(); - await c1.SubscribeAsync(new TopicFilter { Topic = "a", QualityOfServiceLevel = MqttQualityOfServiceLevel.AtLeastOnce }); + await client.SubscribeAsync(new TopicFilter { Topic = "a", QualityOfServiceLevel = MqttQualityOfServiceLevel.AtLeastOnce }); - await s.PublishAsync(message); - await Task.Delay(500); - } - finally - { - await s.StopAsync(); - } + await server.PublishAsync(message); - Assert.AreEqual(1, receivedMessagesCount); + await Task.Delay(1000); + await server.StopAsync(); + + Assert.AreEqual(1, receivedMessagesCount); + } } [TestMethod] @@ -438,40 +434,43 @@ namespace MQTTnet.Tests [TestMethod] public async Task MqttServer_Lots_Of_Retained_Messages() { - const int ClientCount = 100; + const int ClientCount = 25; - var server = new MqttFactory().CreateMqttServer(); - try + using (var testSetup = new TestSetup()) { - await server.StartAsync(new MqttServerOptionsBuilder().Build()); - - Parallel.For( - 0, - ClientCount, - new ParallelOptions { MaxDegreeOfParallelism = 10 }, - i => + var server = await testSetup.StartServerAsync(); + + var tasks = new ConcurrentBag(); + for (var i = 0; i < ClientCount; i++) { - using (var client = new MqttFactory().CreateMqttClient()) + var clientId = i; + tasks.Add(Task.Run(async () => { - client.ConnectAsync(new MqttClientOptionsBuilder().WithTcpServer("localhost").Build()) - .GetAwaiter().GetResult(); - - for (var j = 0; j < 10; j++) + try { - // Clear retained message. - client.PublishAsync(new MqttApplicationMessageBuilder().WithTopic("r" + i) - .WithPayload(new byte[0]).WithRetainFlag().Build()).GetAwaiter().GetResult(); + using (var client = await testSetup.ConnectClientAsync()) + { + // Clear retained message. + await client.PublishAsync(new MqttApplicationMessageBuilder().WithTopic("r" + clientId) + .WithPayload(new byte[0]).WithRetainFlag().Build()); - // Set retained message. - client.PublishAsync(new MqttApplicationMessageBuilder().WithTopic("r" + i) - .WithPayload("value" + j).WithRetainFlag().Build()).GetAwaiter().GetResult(); - } + // Set retained message. + await client.PublishAsync(new MqttApplicationMessageBuilder().WithTopic("r" + clientId) + .WithPayload("value").WithRetainFlag().Build()); - Thread.Sleep(100); + await Task.Delay(10); - client.DisconnectAsync().GetAwaiter().GetResult(); - } - }); + await client.DisconnectAsync(); + } + } + catch (Exception exception) + { + testSetup.TrackException(exception); + } + })); + } + + await Task.WhenAll(tasks.ToArray()); await Task.Delay(1000); @@ -484,10 +483,6 @@ namespace MQTTnet.Tests Assert.IsTrue(retainedMessages.Any(m => m.Topic == "r" + i)); } } - finally - { - await server.StopAsync(); - } } [TestMethod] @@ -986,6 +981,23 @@ namespace MQTTnet.Tests MqttQualityOfServiceLevel filterQualityOfServiceLevel, int expectedReceivedMessagesCount) { + //using (var testSetup = new TestSetup()) + //{ + // var server = await testSetup.StartServerAsync(new MqttServerOptionsBuilder()); + + // var clientOptions = new MqttClientOptionsBuilder(); + // var c1 = await testSetup.ConnectClientAsync(clientOptions); + // await Task.Delay(500); + // Assert.AreEqual(1, (await server.GetClientSessionsStatusAsync()).Count); + + // await c1.DisconnectAsync(); + // await Task.Delay(500); + + // Assert.AreEqual(0, (await server.GetClientSessionsStatusAsync()).Count); + //} + + + var s = new MqttFactory().CreateMqttServer(); var receivedMessagesCount = 0; @@ -994,13 +1006,15 @@ namespace MQTTnet.Tests await s.StartAsync(new MqttServerOptions()); var c1 = new MqttFactory().CreateMqttClient(); - c1.ApplicationMessageReceived += (_, __) => receivedMessagesCount++; - await c1.ConnectAsync(new MqttClientOptionsBuilder().WithTcpServer("localhost").Build()); + c1.UseReceivedApplicationMessageHandler(c => receivedMessagesCount++); + + await c1.ConnectAsync(new MqttClientOptionsBuilder().WithTcpServer("broker.hivemq.com").Build()); await c1.SubscribeAsync(new TopicFilterBuilder().WithTopic(topicFilter).WithQualityOfServiceLevel(filterQualityOfServiceLevel).Build()); var c2 = new MqttFactory().CreateMqttClient(); - await c2.ConnectAsync(new MqttClientOptionsBuilder().WithTcpServer("localhost").Build()); + await c2.ConnectAsync(new MqttClientOptionsBuilder().WithTcpServer("broker.hivemq.com").Build()); await c2.PublishAsync(builder => builder.WithTopic(topic).WithPayload(new byte[0]).WithQualityOfServiceLevel(qualityOfServiceLevel)); + await c2.DisconnectAsync().ConfigureAwait(false); await Task.Delay(500); await c1.UnsubscribeAsync(topicFilter); diff --git a/Tests/MQTTnet.Core.Tests/TestSetup.cs b/Tests/MQTTnet.Core.Tests/TestSetup.cs index 4ffc3ec..b1515ca 100644 --- a/Tests/MQTTnet.Core.Tests/TestSetup.cs +++ b/Tests/MQTTnet.Core.Tests/TestSetup.cs @@ -1,6 +1,6 @@ using System; using System.Collections.Generic; -using System.Threading; +using System.Linq; using System.Threading.Tasks; using MQTTnet.Client; using MQTTnet.Client.Options; @@ -16,10 +16,15 @@ namespace MQTTnet.Tests private readonly IMqttNetLogger _serverLogger = new MqttNetLogger("server"); private readonly IMqttNetLogger _clientLogger = new MqttNetLogger("client"); + private readonly List _serverErrors = new List(); + private readonly List _clientErrors = new List(); + + private readonly List _exceptions = new List(); + private IMqttServer _server; - private long _serverErrorsCount; - private long _clientErrorsCount; + public bool IgnoreClientLogErrors { get; set; } + public bool IgnoreServerLogErrors { get; set; } public TestSetup() { @@ -27,19 +32,30 @@ namespace MQTTnet.Tests { if (e.TraceMessage.Level == MqttNetLogLevel.Error) { - Interlocked.Increment(ref _serverErrorsCount); + lock (_serverErrors) + { + _serverErrors.Add(e.TraceMessage.ToString()); + } } }; _clientLogger.LogMessagePublished += (s, e) => { - if (e.TraceMessage.Level == MqttNetLogLevel.Error) + lock (_clientErrors) { - Interlocked.Increment(ref _clientErrorsCount); + if (e.TraceMessage.Level == MqttNetLogLevel.Error) + { + _clientErrors.Add(e.TraceMessage.ToString()); + } } }; } + public Task StartServerAsync() + { + return StartServerAsync(new MqttServerOptionsBuilder()); + } + public async Task StartServerAsync(MqttServerOptionsBuilder options) { if (_server != null) @@ -53,40 +69,63 @@ namespace MQTTnet.Tests return _server; } + public Task ConnectClientAsync() + { + return ConnectClientAsync(new MqttClientOptionsBuilder()); + } + public async Task ConnectClientAsync(MqttClientOptionsBuilder options) { var client = _mqttFactory.CreateMqttClient(_clientLogger); - _clients.Add(client); - await client.ConnectAsync(options.WithTcpServer("localhost", 1888).Build()); + _clients.Add(client); return client; } public void ThrowIfLogErrors() { - if (_serverErrorsCount > 0) + lock (_serverErrors) { - throw new Exception($"Server had {_serverErrorsCount} errors."); + if (!IgnoreServerLogErrors && _serverErrors.Count > 0) + { + throw new Exception($"Server had {_serverErrors.Count} errors (${string.Join(Environment.NewLine, _serverErrors)})."); + } } - if (_clientErrorsCount > 0) + lock (_clientErrors) { - throw new Exception($"Client(s) had {_clientErrorsCount} errors."); + if (!IgnoreClientLogErrors && _clientErrors.Count > 0) + { + throw new Exception($"Client(s) had {_clientErrors.Count} errors (${string.Join(Environment.NewLine, _clientErrors)})."); + } } } public void Dispose() { - ThrowIfLogErrors(); - foreach (var mqttClient in _clients) { - mqttClient.DisconnectAsync().GetAwaiter().GetResult(); - mqttClient.Dispose(); + mqttClient?.DisconnectAsync().GetAwaiter().GetResult(); + mqttClient?.Dispose(); } _server.StopAsync().GetAwaiter().GetResult(); + + ThrowIfLogErrors(); + + if (_exceptions.Any()) + { + throw new Exception($"{_exceptions.Count} exceptions tracked."); + } + } + + public void TrackException(Exception exception) + { + lock (_exceptions) + { + _exceptions.Add(exception); + } } } }