diff --git a/Source/MQTTnet.Extensions.ManagedClient/ManagedMqttClientStorageManager.cs b/Source/MQTTnet.Extensions.ManagedClient/ManagedMqttClientStorageManager.cs index 821a7d9..772f102 100644 --- a/Source/MQTTnet.Extensions.ManagedClient/ManagedMqttClientStorageManager.cs +++ b/Source/MQTTnet.Extensions.ManagedClient/ManagedMqttClientStorageManager.cs @@ -30,7 +30,7 @@ namespace MQTTnet.Extensions.ManagedClient { if (applicationMessage == null) throw new ArgumentNullException(nameof(applicationMessage)); - using (await _messagesLock.LockAsync(CancellationToken.None).ConfigureAwait(false)) + using (await _messagesLock.WaitAsync(CancellationToken.None).ConfigureAwait(false)) { _messages.Add(applicationMessage); await SaveAsync().ConfigureAwait(false); @@ -41,7 +41,7 @@ namespace MQTTnet.Extensions.ManagedClient { if (applicationMessage == null) throw new ArgumentNullException(nameof(applicationMessage)); - using (await _messagesLock.LockAsync(CancellationToken.None).ConfigureAwait(false)) + using (await _messagesLock.WaitAsync(CancellationToken.None).ConfigureAwait(false)) { var index = _messages.IndexOf(applicationMessage); if (index == -1) diff --git a/Source/MQTTnet/Formatter/MqttPacketReader.cs b/Source/MQTTnet/Formatter/MqttPacketReader.cs index dfb633a..eec643a 100644 --- a/Source/MQTTnet/Formatter/MqttPacketReader.cs +++ b/Source/MQTTnet/Formatter/MqttPacketReader.cs @@ -63,7 +63,7 @@ namespace MQTTnet.Formatter { var offset = 0; var multiplier = 128; - var value = (initialEncodedByte & 127); + var value = initialEncodedByte & 127; int encodedByte = initialEncodedByte; while ((encodedByte & 128) != 0) diff --git a/Source/MQTTnet/Formatter/MqttPacketWriter.cs b/Source/MQTTnet/Formatter/MqttPacketWriter.cs index bf417f7..da483bb 100644 --- a/Source/MQTTnet/Formatter/MqttPacketWriter.cs +++ b/Source/MQTTnet/Formatter/MqttPacketWriter.cs @@ -15,6 +15,7 @@ namespace MQTTnet.Formatter { public static int MaxBufferSize = 4096; + // TODO: Consider using the ArrayPool here together with FreeBuffer. private byte[] _buffer = new byte[128]; private int _offset; @@ -28,13 +29,18 @@ namespace MQTTnet.Formatter return (byte)fixedHeader; } - public static ArraySegment EncodeVariableByteInteger(uint value) + public static ArraySegment EncodeVariableLengthInteger(uint value) { - if (value <= 0) + if (value == 0) { return new ArraySegment(new byte[1], 0, 1); } + if (value <= 127) + { + return new ArraySegment(new[] { (byte)value }, 0, 1); + } + var buffer = new byte[4]; var bufferOffset = 0; @@ -57,7 +63,7 @@ namespace MQTTnet.Formatter public void WriteVariableLengthInteger(uint value) { - Write(EncodeVariableByteInteger(value)); + Write(EncodeVariableLengthInteger(value)); } public void WriteWithLengthPrefix(string value) @@ -80,7 +86,7 @@ namespace MQTTnet.Formatter EnsureAdditionalCapacity(1); _buffer[_offset] = @byte; - IncreasePostition(1); + IncreasePosition(1); } public void Write(ushort value) @@ -88,9 +94,9 @@ namespace MQTTnet.Formatter EnsureAdditionalCapacity(2); _buffer[_offset] = (byte)(value >> 8); - IncreasePostition(1); + IncreasePosition(1); _buffer[_offset] = (byte)value; - IncreasePostition(1); + IncreasePosition(1); } public void Write(byte[] buffer, int offset, int count) @@ -100,7 +106,7 @@ namespace MQTTnet.Formatter EnsureAdditionalCapacity(count); Array.Copy(buffer, offset, _buffer, _offset, count); - IncreasePostition(count); + IncreasePosition(count); } public void Write(ArraySegment buffer) @@ -122,9 +128,9 @@ namespace MQTTnet.Formatter Write(propertyWriter._buffer, 0, propertyWriter.Length); } - public void Reset() + public void Reset(int length) { - Length = 5; + Length = length; } public void Seek(int position) @@ -185,7 +191,7 @@ namespace MQTTnet.Formatter } [MethodImpl(MethodImplOptions.AggressiveInlining)] - private void IncreasePostition(int length) + private void IncreasePosition(int length) { _offset += length; diff --git a/Source/MQTTnet/Formatter/V3/MqttV310PacketFormatter.cs b/Source/MQTTnet/Formatter/V3/MqttV310PacketFormatter.cs index e9f8602..923b9c6 100644 --- a/Source/MQTTnet/Formatter/V3/MqttV310PacketFormatter.cs +++ b/Source/MQTTnet/Formatter/V3/MqttV310PacketFormatter.cs @@ -21,13 +21,13 @@ namespace MQTTnet.Formatter.V3 if (packet == null) throw new ArgumentNullException(nameof(packet)); // Leave enough head space for max header size (fixed + 4 variable remaining length = 5 bytes) - _packetWriter.Reset(); + _packetWriter.Reset(5); _packetWriter.Seek(5); var fixedHeader = EncodePacket(packet, _packetWriter); var remainingLength = (uint)(_packetWriter.Length - 5); - var remainingLengthBuffer = MqttPacketWriter.EncodeVariableByteInteger(remainingLength); + var remainingLengthBuffer = MqttPacketWriter.EncodeVariableLengthInteger(remainingLength); var headerSize = FixedHeaderSize + remainingLengthBuffer.Count; var headerOffset = 5 - headerSize; diff --git a/Source/MQTTnet/Formatter/V5/MqttV500PacketEncoder.cs b/Source/MQTTnet/Formatter/V5/MqttV500PacketEncoder.cs index 0c0019c..2dfde06 100644 --- a/Source/MQTTnet/Formatter/V5/MqttV500PacketEncoder.cs +++ b/Source/MQTTnet/Formatter/V5/MqttV500PacketEncoder.cs @@ -15,13 +15,13 @@ namespace MQTTnet.Formatter.V5 if (packet == null) throw new ArgumentNullException(nameof(packet)); // Leave enough head space for max header size (fixed + 4 variable remaining length = 5 bytes) - _packetWriter.Reset(); + _packetWriter.Reset(5); _packetWriter.Seek(5); var fixedHeader = EncodePacket(packet, _packetWriter); var remainingLength = (uint)(_packetWriter.Length - 5); - var remainingLengthBuffer = MqttPacketWriter.EncodeVariableByteInteger(remainingLength); + var remainingLengthBuffer = MqttPacketWriter.EncodeVariableLengthInteger(remainingLength); var headerSize = 1 + remainingLengthBuffer.Count; var headerOffset = 5 - headerSize; diff --git a/Source/MQTTnet/Implementations/MqttTcpChannel.cs b/Source/MQTTnet/Implementations/MqttTcpChannel.cs index 978f8b6..e389ef0 100644 --- a/Source/MQTTnet/Implementations/MqttTcpChannel.cs +++ b/Source/MQTTnet/Implementations/MqttTcpChannel.cs @@ -52,11 +52,15 @@ namespace MQTTnet.Implementations _socket = new Socket(SocketType.Stream, ProtocolType.Tcp) { NoDelay = true }; } + // Workaround for: workaround for https://github.com/dotnet/corefx/issues/24430 + using (cancellationToken.Register(() => _socket.Dispose())) + { #if NET452 || NET461 - await Task.Factory.FromAsync(_socket.BeginConnect, _socket.EndConnect, _options.Server, _options.GetPort(), null).ConfigureAwait(false); + await Task.Factory.FromAsync(_socket.BeginConnect, _socket.EndConnect, _options.Server, _options.GetPort(), null).ConfigureAwait(false); #else - await _socket.ConnectAsync(_options.Server, _options.GetPort()).ConfigureAwait(false); + await _socket.ConnectAsync(_options.Server, _options.GetPort()).ConfigureAwait(false); #endif + } SslStream sslStream = null; if (_options.TlsOptions.UseTls) @@ -74,14 +78,23 @@ namespace MQTTnet.Implementations return Task.FromResult(0); } - public Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + public async Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) { - return _stream.ReadAsync(buffer, offset, count, cancellationToken); + // Workaround for: https://github.com/dotnet/corefx/issues/24430 + using (cancellationToken.Register(() => _socket.Dispose())) + { + return await _stream.ReadAsync(buffer, offset, count, cancellationToken).ConfigureAwait(false); + } } - public Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + public async Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) { - return _stream.WriteAsync(buffer, offset, count, cancellationToken); + // Workaround for: https://github.com/dotnet/corefx/issues/24430 + using (cancellationToken.Register(() => _socket.Dispose())) + { + await _stream.WriteAsync(buffer, offset, count, cancellationToken).ConfigureAwait(false); + await _stream.FlushAsync(cancellationToken); + } } public void Dispose() @@ -89,10 +102,11 @@ namespace MQTTnet.Implementations Cleanup(ref _stream, s => s.Dispose()); Cleanup(ref _socket, s => { - if (s.Connected) - { - s.Shutdown(SocketShutdown.Both); - } + //if (s.Connected) + //{ + // s.Shutdown(SocketShutdown.Both); + //} + s.Dispose(); }); } diff --git a/Source/MQTTnet/Internal/AsyncLock.cs b/Source/MQTTnet/Internal/AsyncLock.cs index 87571c2..9b7eefd 100644 --- a/Source/MQTTnet/Internal/AsyncLock.cs +++ b/Source/MQTTnet/Internal/AsyncLock.cs @@ -15,14 +15,23 @@ namespace MQTTnet.Internal _releaser = Task.FromResult((IDisposable)new Releaser(this)); } - public Task LockAsync(CancellationToken cancellationToken) + public Task WaitAsync() { - var wait = _semaphore.WaitAsync(cancellationToken); - return wait.IsCompleted ? - _releaser : - wait.ContinueWith((_, state) => (IDisposable)state, - _releaser.Result, cancellationToken, - TaskContinuationOptions.ExecuteSynchronously, TaskScheduler.Default); + return WaitAsync(CancellationToken.None); + } + + public Task WaitAsync(CancellationToken cancellationToken) + { + var task = _semaphore.WaitAsync(cancellationToken); + if (task.IsCompleted) + { + return _releaser; + } + + return task.ContinueWith( + (_, state) => (IDisposable)state, + _releaser.Result, + cancellationToken, TaskContinuationOptions.ExecuteSynchronously, TaskScheduler.Default); } public void Dispose() diff --git a/Source/MQTTnet/Server/IMqttClientSession.cs b/Source/MQTTnet/Server/IMqttClientSession.cs index 51341e9..f1e3010 100644 --- a/Source/MQTTnet/Server/IMqttClientSession.cs +++ b/Source/MQTTnet/Server/IMqttClientSession.cs @@ -1,4 +1,5 @@ using System; +using System.Threading.Tasks; namespace MQTTnet.Server { @@ -6,6 +7,6 @@ namespace MQTTnet.Server { string ClientId { get; } - void Stop(MqttClientDisconnectType disconnectType); + Task StopAsync(MqttClientDisconnectType disconnectType); } } \ No newline at end of file diff --git a/Source/MQTTnet/Server/IMqttServer.cs b/Source/MQTTnet/Server/IMqttServer.cs index 3a45973..e42c903 100644 --- a/Source/MQTTnet/Server/IMqttServer.cs +++ b/Source/MQTTnet/Server/IMqttServer.cs @@ -16,11 +16,8 @@ namespace MQTTnet.Server IMqttServerOptions Options { get; } - [Obsolete("This method is no longer async. Use the not async method.")] Task> GetClientSessionsStatusAsync(); - IList GetClientSessionsStatus(); - IList GetRetainedMessages(); Task ClearRetainedMessagesAsync(); diff --git a/Source/MQTTnet/Server/IMqttServerApplicationMessageInterceptor.cs b/Source/MQTTnet/Server/IMqttServerApplicationMessageInterceptor.cs new file mode 100644 index 0000000..64fe047 --- /dev/null +++ b/Source/MQTTnet/Server/IMqttServerApplicationMessageInterceptor.cs @@ -0,0 +1,9 @@ +using System.Threading.Tasks; + +namespace MQTTnet.Server +{ + public interface IMqttServerApplicationMessageInterceptor + { + Task InterceptApplicationMessagePublishAsync(MqttApplicationMessageInterceptorContext context); + } +} diff --git a/Source/MQTTnet/Server/IMqttServerClientMessageQueueInterceptor.cs b/Source/MQTTnet/Server/IMqttServerClientMessageQueueInterceptor.cs new file mode 100644 index 0000000..c6f97a7 --- /dev/null +++ b/Source/MQTTnet/Server/IMqttServerClientMessageQueueInterceptor.cs @@ -0,0 +1,9 @@ +using System.Threading.Tasks; + +namespace MQTTnet.Server +{ + public interface IMqttServerClientMessageQueueInterceptor + { + Task InterceptClientMessageQueueEnqueueAsync(MqttClientMessageQueueInterceptorContext context); + } +} diff --git a/Source/MQTTnet/Server/IMqttServerConnectionValidator.cs b/Source/MQTTnet/Server/IMqttServerConnectionValidator.cs new file mode 100644 index 0000000..bd38c83 --- /dev/null +++ b/Source/MQTTnet/Server/IMqttServerConnectionValidator.cs @@ -0,0 +1,9 @@ +using System.Threading.Tasks; + +namespace MQTTnet.Server +{ + public interface IMqttServerConnectionValidator + { + Task ValidateConnection(MqttConnectionValidatorContext context); + } +} diff --git a/Source/MQTTnet/Server/IMqttServerOptions.cs b/Source/MQTTnet/Server/IMqttServerOptions.cs index f9334fd..d840c8f 100644 --- a/Source/MQTTnet/Server/IMqttServerOptions.cs +++ b/Source/MQTTnet/Server/IMqttServerOptions.cs @@ -11,10 +11,10 @@ namespace MQTTnet.Server TimeSpan DefaultCommunicationTimeout { get; } - Action ConnectionValidator { get; } - Action SubscriptionInterceptor { get; } - Action ApplicationMessageInterceptor { get; } - Action ClientMessageQueueInterceptor { get; } + IMqttServerConnectionValidator ConnectionValidator { get; } + IMqttServerSubscriptionInterceptor SubscriptionInterceptor { get; } + IMqttServerApplicationMessageInterceptor ApplicationMessageInterceptor { get; } + IMqttServerClientMessageQueueInterceptor ClientMessageQueueInterceptor { get; } MqttServerTcpEndpointOptions DefaultEndpointOptions { get; } MqttServerTlsTcpEndpointOptions TlsEndpointOptions { get; } diff --git a/Source/MQTTnet/Server/IMqttServerSubscriptionInterceptor.cs b/Source/MQTTnet/Server/IMqttServerSubscriptionInterceptor.cs new file mode 100644 index 0000000..a7ce95e --- /dev/null +++ b/Source/MQTTnet/Server/IMqttServerSubscriptionInterceptor.cs @@ -0,0 +1,9 @@ +using System.Threading.Tasks; + +namespace MQTTnet.Server +{ + public interface IMqttServerSubscriptionInterceptor + { + Task InterceptSubscriptionAsync(MqttSubscriptionInterceptorContext context); + } +} diff --git a/Source/MQTTnet/Server/MqttClientKeepAliveMonitor.cs b/Source/MQTTnet/Server/MqttClientKeepAliveMonitor.cs index fbe1f87..e8e7c15 100644 --- a/Source/MQTTnet/Server/MqttClientKeepAliveMonitor.cs +++ b/Source/MQTTnet/Server/MqttClientKeepAliveMonitor.cs @@ -81,7 +81,7 @@ namespace MQTTnet.Server if (!_isPaused && _lastPacketReceivedTracker.Elapsed.TotalSeconds >= keepAlivePeriod * 1.5D) { _logger.Warning(null, "Client '{0}': Did not receive any packet or keep alive signal.", _clientSession.ClientId); - _clientSession.Stop(MqttClientDisconnectType.NotClean); + await _clientSession.StopAsync(MqttClientDisconnectType.NotClean).ConfigureAwait(false); return; } diff --git a/Source/MQTTnet/Server/MqttClientPendingPacketsQueue.cs b/Source/MQTTnet/Server/MqttClientPendingPacketsQueue.cs index 503d992..eddd788 100644 --- a/Source/MQTTnet/Server/MqttClientPendingPacketsQueue.cs +++ b/Source/MQTTnet/Server/MqttClientPendingPacketsQueue.cs @@ -133,7 +133,7 @@ namespace MQTTnet.Server return; } - adapter.SendPacketAsync(packet, cancellationToken).GetAwaiter().GetResult(); + await adapter.SendPacketAsync(packet, cancellationToken).ConfigureAwait(false); _logger.Verbose("Enqueued packet sent (ClientId: {0}).", _clientSession.ClientId); } @@ -167,7 +167,7 @@ namespace MQTTnet.Server if (!cancellationToken.IsCancellationRequested) { - _clientSession.Stop(MqttClientDisconnectType.NotClean); + await _clientSession.StopAsync(MqttClientDisconnectType.NotClean).ConfigureAwait(false); } } } diff --git a/Source/MQTTnet/Server/MqttClientSession.cs b/Source/MQTTnet/Server/MqttClientSession.cs index 6eacfa6..9eaab4f 100644 --- a/Source/MQTTnet/Server/MqttClientSession.cs +++ b/Source/MQTTnet/Server/MqttClientSession.cs @@ -79,22 +79,20 @@ namespace MQTTnet.Server return _workerTask; } - public void Stop(MqttClientDisconnectType type) + public Task StopAsync(MqttClientDisconnectType type) { - Stop(type, false); + return StopAsync(type, false); } - public Task SubscribeAsync(IList topicFilters) + public async Task SubscribeAsync(IList topicFilters) { if (topicFilters == null) throw new ArgumentNullException(nameof(topicFilters)); var packet = new MqttSubscribePacket(); packet.TopicFilters.AddRange(topicFilters); - _subscriptionsManager.Subscribe(packet); - - EnqueueSubscribedRetainedMessages(topicFilters); - return Task.FromResult(0); + await _subscriptionsManager.SubscribeAsync(packet).ConfigureAwait(false); + await EnqueueSubscribedRetainedMessagesAsync(topicFilters).ConfigureAwait(false); } public Task UnsubscribeAsync(IList topicFilters) @@ -122,7 +120,7 @@ namespace MQTTnet.Server _cancellationTokenSource?.Dispose(); _cancellationTokenSource = null; } - + private async Task RunInternalAsync(MqttConnectPacket connectPacket, IMqttChannelAdapter adapter) { if (adapter == null) throw new ArgumentNullException(nameof(adapter)); @@ -131,7 +129,7 @@ namespace MQTTnet.Server { if (_cancellationTokenSource != null) { - Stop(MqttClientDisconnectType.Clean, true); + await StopAsync(MqttClientDisconnectType.Clean, true).ConfigureAwait(false); } adapter.ReadingPacketStarted += OnAdapterReadingPacketStarted; @@ -139,15 +137,15 @@ namespace MQTTnet.Server _cancellationTokenSource = new CancellationTokenSource(); - //workaround for https://github.com/dotnet/corefx/issues/24430 -#pragma warning disable 4014 - _cleanupHandle = _cancellationTokenSource.Token.Register(async () => - { - await TryDisconnectAdapterAsync(adapter).ConfigureAwait(false); - TryDisposeAdapter(adapter); - }); -#pragma warning restore 4014 - //end workaround +// //workaround for https://github.com/dotnet/corefx/issues/24430 +//#pragma warning disable 4014 +// _cleanupHandle = _cancellationTokenSource.Token.Register(async () => +// { +// await TryDisconnectAdapterAsync(adapter).ConfigureAwait(false); +// TryDisposeAdapter(adapter); +// }); +//#pragma warning restore 4014 +// //end workaround _wasCleanDisconnect = false; _willMessage = connectPacket.WillMessage; @@ -165,7 +163,7 @@ namespace MQTTnet.Server if (packet != null) { _keepAliveMonitor.PacketReceived(packet); - ProcessReceivedPacket(adapter, packet, _cancellationTokenSource.Token); + await ProcessReceivedPacketAsync(adapter, packet, _cancellationTokenSource.Token).ConfigureAwait(false); } } } @@ -190,51 +188,55 @@ namespace MQTTnet.Server _logger.Error(exception, "Client '{0}': Unhandled exception while receiving client packets.", ClientId); } - Stop(MqttClientDisconnectType.NotClean, true); + await StopAsync(MqttClientDisconnectType.NotClean, true).ConfigureAwait(false); } finally { + adapter.ReadingPacketStarted -= OnAdapterReadingPacketStarted; + adapter.ReadingPacketCompleted -= OnAdapterReadingPacketCompleted; + _adapterEndpoint = null; _adapterProtocolVersion = null; // Uncomment as soon as the workaround above is no longer needed. + // Also called in outer scope! //await TryDisconnectAdapterAsync(adapter).ConfigureAwait(false); //TryDisposeAdapter(adapter); _cleanupHandle?.Dispose(); _cleanupHandle = null; - + _cancellationTokenSource?.Cancel(false); _cancellationTokenSource?.Dispose(); _cancellationTokenSource = null; } } - private void TryDisposeAdapter(IMqttChannelAdapter adapter) - { - if (adapter == null) - { - return; - } - - try - { - adapter.ReadingPacketStarted -= OnAdapterReadingPacketStarted; - adapter.ReadingPacketCompleted -= OnAdapterReadingPacketCompleted; - - adapter.Dispose(); - } - catch (Exception exception) - { - _logger.Error(exception, exception.Message); - } - finally - { - adapter.Dispose(); - } - } - - private void Stop(MqttClientDisconnectType type, bool isInsideSession) + ////private void TryDisposeAdapter(IMqttChannelAdapter adapter) + ////{ + //// if (adapter == null) + //// { + //// return; + //// } + + //// try + //// { + //// adapter.ReadingPacketStarted -= OnAdapterReadingPacketStarted; + //// adapter.ReadingPacketCompleted -= OnAdapterReadingPacketCompleted; + + //// adapter.Dispose(); + //// } + //// catch (Exception exception) + //// { + //// _logger.Error(exception, exception.Message); + //// } + //// finally + //// { + //// adapter.Dispose(); + //// } + ////} + + private async Task StopAsync(MqttClientDisconnectType type, bool isInsideSession) { try { @@ -257,7 +259,10 @@ namespace MQTTnet.Server if (!isInsideSession) { - _workerTask?.GetAwaiter().GetResult(); + if (_workerTask != null) + { + await _workerTask.ConfigureAwait(false); + } } } finally @@ -267,7 +272,7 @@ namespace MQTTnet.Server } } - public void EnqueueApplicationMessage(MqttClientSession senderClientSession, MqttApplicationMessage applicationMessage, bool isRetainedApplicationMessage) + public async Task EnqueueApplicationMessageAsync(MqttClientSession senderClientSession, MqttApplicationMessage applicationMessage, bool isRetainedApplicationMessage) { if (applicationMessage == null) throw new ArgumentNullException(nameof(applicationMessage)); @@ -278,10 +283,10 @@ namespace MQTTnet.Server } var publishPacket = _channelAdapter.PacketFormatterAdapter.DataConverter.CreatePublishPacket(applicationMessage); - + // Set the retain flag to true according to [MQTT-3.3.1-8] and [MQTT-3.3.1-9]. publishPacket.Retain = isRetainedApplicationMessage; - + if (publishPacket.QualityOfServiceLevel > 0) { publishPacket.PacketIdentifier = _packetIdentifierProvider.GetNewPacketIdentifier(); @@ -294,7 +299,10 @@ namespace MQTTnet.Server ClientId, applicationMessage); - _options.ClientMessageQueueInterceptor?.Invoke(context); + if (_options.ClientMessageQueueInterceptor != null) + { + await _options.ClientMessageQueueInterceptor.InterceptClientMessageQueueEnqueueAsync(context).ConfigureAwait(false); + } if (!context.AcceptEnqueue || context.ApplicationMessage == null) { @@ -309,35 +317,33 @@ namespace MQTTnet.Server _pendingPacketsQueue.Enqueue(publishPacket); } - private async Task TryDisconnectAdapterAsync(IMqttChannelAdapter adapter) - { - if (adapter == null) - { - return; - } - - try - { - await adapter.DisconnectAsync(_options.DefaultCommunicationTimeout, CancellationToken.None).ConfigureAwait(false); - } - catch (Exception exception) - { - _logger.Error(exception, "Error while disconnecting channel adapter."); - } - } - - private void ProcessReceivedPacket(IMqttChannelAdapter adapter, MqttBasePacket packet, CancellationToken cancellationToken) + //private async Task TryDisconnectAdapterAsync(IMqttChannelAdapter adapter) + //{ + // if (adapter == null) + // { + // return; + // } + + // try + // { + // await adapter.DisconnectAsync(_options.DefaultCommunicationTimeout, CancellationToken.None).ConfigureAwait(false); + // } + // catch (Exception exception) + // { + // _logger.Error(exception, "Error while disconnecting channel adapter."); + // } + //} + + private Task ProcessReceivedPacketAsync(IMqttChannelAdapter adapter, MqttBasePacket packet, CancellationToken cancellationToken) { if (packet is MqttPublishPacket publishPacket) { - HandleIncomingPublishPacket(adapter, publishPacket, cancellationToken); - return; + return HandleIncomingPublishPacketAsync(adapter, publishPacket, cancellationToken); } if (packet is MqttPingReqPacket) { - adapter.SendPacketAsync(new MqttPingRespPacket(), cancellationToken).GetAwaiter().GetResult(); - return; + return adapter.SendPacketAsync(new MqttPingRespPacket(), cancellationToken); } if (packet is MqttPubRelPacket pubRelPacket) @@ -348,8 +354,7 @@ namespace MQTTnet.Server ReasonCode = MqttPubCompReasonCode.Success }; - adapter.SendPacketAsync(responsePacket, cancellationToken).GetAwaiter().GetResult(); - return; + return adapter.SendPacketAsync(responsePacket, cancellationToken); } if (packet is MqttPubRecPacket pubRecPacket) @@ -360,91 +365,83 @@ namespace MQTTnet.Server ReasonCode = MqttPubRelReasonCode.Success }; - adapter.SendPacketAsync(responsePacket, cancellationToken).GetAwaiter().GetResult(); - return; + return adapter.SendPacketAsync(responsePacket, cancellationToken); } if (packet is MqttPubAckPacket || packet is MqttPubCompPacket) { - return; + return Task.FromResult(0); } if (packet is MqttSubscribePacket subscribePacket) { - HandleIncomingSubscribePacket(adapter, subscribePacket, cancellationToken); - return; + return HandleIncomingSubscribePacketAsync(adapter, subscribePacket, cancellationToken); } if (packet is MqttUnsubscribePacket unsubscribePacket) { - HandleIncomingUnsubscribePacket(adapter, unsubscribePacket, cancellationToken); - return; + return HandleIncomingUnsubscribePacketAsync(adapter, unsubscribePacket, cancellationToken); } if (packet is MqttDisconnectPacket) { - Stop(MqttClientDisconnectType.Clean, true); - return; + return StopAsync(MqttClientDisconnectType.Clean, true); } if (packet is MqttConnectPacket) { - Stop(MqttClientDisconnectType.NotClean, true); - return; + return StopAsync(MqttClientDisconnectType.NotClean, true); } _logger.Warning(null, "Client '{0}': Received not supported packet ({1}). Closing connection.", ClientId, packet); - Stop(MqttClientDisconnectType.NotClean, true); + return StopAsync(MqttClientDisconnectType.NotClean, true); } - private void EnqueueSubscribedRetainedMessages(ICollection topicFilters) + private async Task EnqueueSubscribedRetainedMessagesAsync(ICollection topicFilters) { - var retainedMessages = _retainedMessagesManager.GetSubscribedMessages(topicFilters); + var retainedMessages = await _retainedMessagesManager.GetSubscribedMessagesAsync(topicFilters).ConfigureAwait(false); foreach (var applicationMessage in retainedMessages) { - EnqueueApplicationMessage(null, applicationMessage, true); + await EnqueueApplicationMessageAsync(null, applicationMessage, true).ConfigureAwait(false); } } - private void HandleIncomingSubscribePacket(IMqttChannelAdapter adapter, MqttSubscribePacket subscribePacket, CancellationToken cancellationToken) + private async Task HandleIncomingSubscribePacketAsync(IMqttChannelAdapter adapter, MqttSubscribePacket subscribePacket, CancellationToken cancellationToken) { - var subscribeResult = _subscriptionsManager.Subscribe(subscribePacket); - adapter.SendPacketAsync(subscribeResult.ResponsePacket, cancellationToken).GetAwaiter().GetResult(); + var subscribeResult = await _subscriptionsManager.SubscribeAsync(subscribePacket).ConfigureAwait(false); + await adapter.SendPacketAsync(subscribeResult.ResponsePacket, cancellationToken).ConfigureAwait(false); if (subscribeResult.CloseConnection) { - Stop(MqttClientDisconnectType.NotClean, true); - return; + await StopAsync(MqttClientDisconnectType.NotClean, true).ConfigureAwait(false); } - EnqueueSubscribedRetainedMessages(subscribePacket.TopicFilters); + await EnqueueSubscribedRetainedMessagesAsync(subscribePacket.TopicFilters).ConfigureAwait(false); } - private void HandleIncomingUnsubscribePacket(IMqttChannelAdapter adapter, MqttUnsubscribePacket unsubscribePacket, CancellationToken cancellationToken) + private Task HandleIncomingUnsubscribePacketAsync(IMqttChannelAdapter adapter, MqttUnsubscribePacket unsubscribePacket, CancellationToken cancellationToken) { var unsubscribeResult = _subscriptionsManager.Unsubscribe(unsubscribePacket); - adapter.SendPacketAsync(unsubscribeResult, cancellationToken).GetAwaiter().GetResult(); + return adapter.SendPacketAsync(unsubscribeResult, cancellationToken); } - private void HandleIncomingPublishPacket(IMqttChannelAdapter adapter, MqttPublishPacket publishPacket, CancellationToken cancellationToken) + private Task HandleIncomingPublishPacketAsync(IMqttChannelAdapter adapter, MqttPublishPacket publishPacket, CancellationToken cancellationToken) { switch (publishPacket.QualityOfServiceLevel) { case MqttQualityOfServiceLevel.AtMostOnce: { HandleIncomingPublishPacketWithQoS0(publishPacket); - break; + return Task.FromResult(0); } case MqttQualityOfServiceLevel.AtLeastOnce: { - HandleIncomingPublishPacketWithQoS1(adapter, publishPacket, cancellationToken); - break; + return HandleIncomingPublishPacketWithQoS1Async(adapter, publishPacket, cancellationToken); } case MqttQualityOfServiceLevel.ExactlyOnce: { - HandleIncomingPublishPacketWithQoS2(adapter, publishPacket, cancellationToken); - break; + return HandleIncomingPublishPacketWithQoS2Async(adapter, publishPacket, cancellationToken); } default: { @@ -456,17 +453,17 @@ namespace MQTTnet.Server private void HandleIncomingPublishPacketWithQoS0(MqttPublishPacket publishPacket) { _sessionsManager.EnqueueApplicationMessage( - this, + this, _channelAdapter.PacketFormatterAdapter.DataConverter.CreateApplicationMessage(publishPacket)); } - private void HandleIncomingPublishPacketWithQoS1( + private Task HandleIncomingPublishPacketWithQoS1Async( IMqttChannelAdapter adapter, MqttPublishPacket publishPacket, CancellationToken cancellationToken) { _sessionsManager.EnqueueApplicationMessage( - this, + this, _channelAdapter.PacketFormatterAdapter.DataConverter.CreateApplicationMessage(publishPacket)); var response = new MqttPubAckPacket @@ -475,10 +472,10 @@ namespace MQTTnet.Server ReasonCode = MqttPubAckReasonCode.Success }; - adapter.SendPacketAsync(response, cancellationToken).GetAwaiter().GetResult(); + return adapter.SendPacketAsync(response, cancellationToken); } - private void HandleIncomingPublishPacketWithQoS2( + private Task HandleIncomingPublishPacketWithQoS2Async( IMqttChannelAdapter adapter, MqttPublishPacket publishPacket, CancellationToken cancellationToken) @@ -492,7 +489,7 @@ namespace MQTTnet.Server ReasonCode = MqttPubRecReasonCode.Success }; - adapter.SendPacketAsync(response, cancellationToken).GetAwaiter().GetResult(); + return adapter.SendPacketAsync(response, cancellationToken); } private void OnAdapterReadingPacketCompleted(object sender, EventArgs e) diff --git a/Source/MQTTnet/Server/MqttClientSessionStatus.cs b/Source/MQTTnet/Server/MqttClientSessionStatus.cs index 7381f94..2673d7e 100644 --- a/Source/MQTTnet/Server/MqttClientSessionStatus.cs +++ b/Source/MQTTnet/Server/MqttClientSessionStatus.cs @@ -25,22 +25,19 @@ namespace MQTTnet.Server public Task DisconnectAsync() { - _session.Stop(MqttClientDisconnectType.NotClean); - return Task.FromResult(0); + return _session.StopAsync(MqttClientDisconnectType.NotClean); } - public Task DeleteSessionAsync() + public async Task DeleteSessionAsync() { try { - _session.Stop(MqttClientDisconnectType.NotClean); + await _session.StopAsync(MqttClientDisconnectType.NotClean).ConfigureAwait(false); } finally { - _sessionsManager.DeleteSession(ClientId); + await _sessionsManager.DeleteSessionAsync(ClientId).ConfigureAwait(false); } - - return Task.FromResult(0); } public Task ClearPendingApplicationMessagesAsync() diff --git a/Source/MQTTnet/Server/MqttClientSessionsManager.cs b/Source/MQTTnet/Server/MqttClientSessionsManager.cs index 748594a..356782e 100644 --- a/Source/MQTTnet/Server/MqttClientSessionsManager.cs +++ b/Source/MQTTnet/Server/MqttClientSessionsManager.cs @@ -1,12 +1,11 @@ using System; using System.Collections.Concurrent; using System.Collections.Generic; -using System.Linq; using System.Threading; using System.Threading.Tasks; using MQTTnet.Adapter; using MQTTnet.Diagnostics; -using MQTTnet.Exceptions; +using MQTTnet.Internal; using MQTTnet.Packets; using MQTTnet.Protocol; @@ -16,9 +15,7 @@ namespace MQTTnet.Server { private readonly BlockingCollection _messageQueue = new BlockingCollection(); - /// - /// manual locking dictionaries is faster than using concurrent dictionary - /// + private readonly AsyncLock _sessionsLock = new AsyncLock(); private readonly Dictionary _sessions = new Dictionary(); private readonly CancellationToken _cancellationToken; @@ -47,16 +44,16 @@ namespace MQTTnet.Server public void Start() { - Task.Factory.StartNew(() => TryProcessQueuedApplicationMessages(_cancellationToken), _cancellationToken, TaskCreationOptions.LongRunning, TaskScheduler.Default); + Task.Factory.StartNew(() => TryProcessQueuedApplicationMessagesAsync(_cancellationToken), _cancellationToken, TaskCreationOptions.LongRunning, TaskScheduler.Default); } - public void Stop() + public async Task StopAsync() { - lock (_sessions) + using (await _sessionsLock.WaitAsync(_cancellationToken).ConfigureAwait(false)) { foreach (var session in _sessions) { - session.Value.Stop(MqttClientDisconnectType.NotClean); + await session.Value.StopAsync(MqttClientDisconnectType.NotClean).ConfigureAwait(false); } _sessions.Clear(); @@ -68,18 +65,21 @@ namespace MQTTnet.Server return Task.Run(() => RunSessionAsync(clientAdapter, _cancellationToken), _cancellationToken); } - public IList GetClientStatus() + public async Task> GetClientStatusAsync() { var result = new List(); - foreach (var session in GetSessions()) + using (await _sessionsLock.WaitAsync(_cancellationToken).ConfigureAwait(false)) { - var status = new MqttClientSessionStatus(this, session); - session.FillStatus(status); + foreach (var session in _sessions.Values) + { + var status = new MqttClientSessionStatus(this, session); + session.FillStatus(status); - result.Add(status); + result.Add(status); + } } - + return result; } @@ -122,9 +122,9 @@ namespace MQTTnet.Server } } - public void DeleteSession(string clientId) + public async Task DeleteSessionAsync(string clientId) { - lock (_sessions) + using (await _sessionsLock.WaitAsync(_cancellationToken).ConfigureAwait(false)) { _sessions.Remove(clientId); } @@ -137,13 +137,13 @@ namespace MQTTnet.Server _messageQueue?.Dispose(); } - private void TryProcessQueuedApplicationMessages(CancellationToken cancellationToken) + private async Task TryProcessQueuedApplicationMessagesAsync(CancellationToken cancellationToken) { while (!cancellationToken.IsCancellationRequested) { try { - TryProcessNextQueuedApplicationMessage(cancellationToken); + await TryProcessNextQueuedApplicationMessageAsync(cancellationToken).ConfigureAwait(false); } catch (OperationCanceledException) { @@ -155,7 +155,7 @@ namespace MQTTnet.Server } } - private void TryProcessNextQueuedApplicationMessage(CancellationToken cancellationToken) + private async Task TryProcessNextQueuedApplicationMessageAsync(CancellationToken cancellationToken) { try { @@ -164,12 +164,12 @@ namespace MQTTnet.Server var sender = enqueuedApplicationMessage.Sender; var applicationMessage = enqueuedApplicationMessage.ApplicationMessage; - var interceptorContext = InterceptApplicationMessage(sender, applicationMessage); + var interceptorContext = await InterceptApplicationMessageAsync(sender, applicationMessage).ConfigureAwait(false); if (interceptorContext != null) { if (interceptorContext.CloseConnection) { - enqueuedApplicationMessage.Sender.Stop(MqttClientDisconnectType.NotClean); + await enqueuedApplicationMessage.Sender.StopAsync(MqttClientDisconnectType.NotClean).ConfigureAwait(false); } if (interceptorContext.ApplicationMessage == null || !interceptorContext.AcceptPublish) @@ -184,12 +184,18 @@ namespace MQTTnet.Server if (applicationMessage.Retain) { - _retainedMessagesManager.HandleMessageAsync(sender?.ClientId, applicationMessage).GetAwaiter().GetResult(); + await _retainedMessagesManager.HandleMessageAsync(sender?.ClientId, applicationMessage).ConfigureAwait(false); } - foreach (var clientSession in GetSessions()) + using (await _sessionsLock.WaitAsync(_cancellationToken).ConfigureAwait(false)) { - clientSession.EnqueueApplicationMessage(enqueuedApplicationMessage.Sender, enqueuedApplicationMessage.ApplicationMessage, false); + foreach (var clientSession in _sessions.Values) + { + await clientSession.EnqueueApplicationMessageAsync( + enqueuedApplicationMessage.Sender, + enqueuedApplicationMessage.ApplicationMessage, + false).ConfigureAwait(false); + } } } catch (OperationCanceledException) @@ -201,35 +207,22 @@ namespace MQTTnet.Server } } - private List GetSessions() - { - lock (_sessions) - { - return _sessions.Values.ToList(); - } - } - private async Task RunSessionAsync(IMqttChannelAdapter clientAdapter, CancellationToken cancellationToken) { var clientId = string.Empty; try { - // TODO: Catch cancel exception here if the first packet was not received and log properly. var firstPacket = await clientAdapter.ReceivePacketAsync(_options.DefaultCommunicationTimeout, cancellationToken).ConfigureAwait(false); - if (firstPacket == null) - { - return; - } - if (!(firstPacket is MqttConnectPacket connectPacket)) { - throw new MqttProtocolViolationException("The first packet from a client must be a 'CONNECT' packet [MQTT-3.1.0-1]."); + _logger.Warning(null, "The first packet from client '{0}' was no 'CONNECT' packet [MQTT-3.1.0-1].", clientAdapter.Endpoint); + return; } clientId = connectPacket.ClientId; - var connectReturnCode = ValidateConnection(connectPacket, clientAdapter); + var connectReturnCode = await ValidateConnectionAsync(connectPacket, clientAdapter).ConfigureAwait(false); if (connectReturnCode != MqttConnectReturnCode.ConnectionAccepted) { await clientAdapter.SendPacketAsync( @@ -242,7 +235,7 @@ namespace MQTTnet.Server return; } - var result = PrepareClientSession(connectPacket); + var result = await PrepareClientSessionAsync(connectPacket).ConfigureAwait(false); await clientAdapter.SendPacketAsync( new MqttConnAckPacket @@ -267,14 +260,17 @@ namespace MQTTnet.Server } finally { + await clientAdapter.DisconnectAsync(_options.DefaultCommunicationTimeout, cancellationToken).ConfigureAwait(false); + clientAdapter.Dispose(); + if (!_options.EnablePersistentSessions) { - DeleteSession(clientId); + await DeleteSessionAsync(clientId).ConfigureAwait(false); } } } - private MqttConnectReturnCode ValidateConnection(MqttConnectPacket connectPacket, IMqttChannelAdapter clientAdapter) + private async Task ValidateConnectionAsync(MqttConnectPacket connectPacket, IMqttChannelAdapter clientAdapter) { if (_options.ConnectionValidator == null) { @@ -288,13 +284,13 @@ namespace MQTTnet.Server connectPacket.WillMessage, clientAdapter.Endpoint); - _options.ConnectionValidator(context); + await _options.ConnectionValidator.ValidateConnection(context).ConfigureAwait(false); return context.ReturnCode; } - private PrepareClientSessionResult PrepareClientSession(MqttConnectPacket connectPacket) + private async Task PrepareClientSessionAsync(MqttConnectPacket connectPacket) { - lock (_sessions) + using (await _sessionsLock.WaitAsync(_cancellationToken).ConfigureAwait(false)) { var isSessionPresent = _sessions.TryGetValue(connectPacket.ClientId, out var clientSession); if (isSessionPresent) @@ -303,7 +299,7 @@ namespace MQTTnet.Server { _sessions.Remove(connectPacket.ClientId); - clientSession.Stop(MqttClientDisconnectType.Clean); + await clientSession.StopAsync(MqttClientDisconnectType.Clean).ConfigureAwait(false); clientSession.Dispose(); clientSession = null; @@ -330,7 +326,7 @@ namespace MQTTnet.Server } } - private MqttApplicationMessageInterceptorContext InterceptApplicationMessage(MqttClientSession sender, MqttApplicationMessage applicationMessage) + private async Task InterceptApplicationMessageAsync(MqttClientSession sender, MqttApplicationMessage applicationMessage) { var interceptor = _options.ApplicationMessageInterceptor; if (interceptor == null) @@ -339,7 +335,7 @@ namespace MQTTnet.Server } var interceptorContext = new MqttApplicationMessageInterceptorContext(sender?.ClientId, applicationMessage); - interceptor(interceptorContext); + await interceptor.InterceptApplicationMessagePublishAsync(interceptorContext).ConfigureAwait(false); return interceptorContext; } } diff --git a/Source/MQTTnet/Server/MqttClientSubscriptionsManager.cs b/Source/MQTTnet/Server/MqttClientSubscriptionsManager.cs index 9f1fff6..3186104 100644 --- a/Source/MQTTnet/Server/MqttClientSubscriptionsManager.cs +++ b/Source/MQTTnet/Server/MqttClientSubscriptionsManager.cs @@ -1,6 +1,7 @@ using System; using System.Collections.Generic; using System.Linq; +using System.Threading.Tasks; using MQTTnet.Packets; using MQTTnet.Protocol; @@ -20,7 +21,7 @@ namespace MQTTnet.Server _eventDispatcher = eventDispatcher ?? throw new ArgumentNullException(nameof(eventDispatcher)); } - public MqttClientSubscribeResult Subscribe(MqttSubscribePacket subscribePacket) + public async Task SubscribeAsync(MqttSubscribePacket subscribePacket) { if (subscribePacket == null) throw new ArgumentNullException(nameof(subscribePacket)); @@ -36,7 +37,7 @@ namespace MQTTnet.Server foreach (var topicFilter in subscribePacket.TopicFilters) { - var interceptorContext = InterceptSubscribe(topicFilter); + var interceptorContext = await InterceptSubscribeAsync(topicFilter).ConfigureAwait(false); if (!interceptorContext.AcceptSubscription) { result.ResponsePacket.ReturnCodes.Add(MqttSubscribeReturnCode.Failure); @@ -146,11 +147,15 @@ namespace MQTTnet.Server } } - private MqttSubscriptionInterceptorContext InterceptSubscribe(TopicFilter topicFilter) + private async Task InterceptSubscribeAsync(TopicFilter topicFilter) { - var interceptorContext = new MqttSubscriptionInterceptorContext(_clientId, topicFilter); - _options.SubscriptionInterceptor?.Invoke(interceptorContext); - return interceptorContext; + var context = new MqttSubscriptionInterceptorContext(_clientId, topicFilter); + if (_options.SubscriptionInterceptor != null) + { + await _options.SubscriptionInterceptor.InterceptSubscriptionAsync(context).ConfigureAwait(false); + } + + return context; } private static CheckSubscriptionsResult CreateSubscriptionResult(MqttQualityOfServiceLevel qosLevel, HashSet subscribedQoSLevels) diff --git a/Source/MQTTnet/Server/MqttRetainedMessagesManager.cs b/Source/MQTTnet/Server/MqttRetainedMessagesManager.cs index bf5de98..52533d1 100644 --- a/Source/MQTTnet/Server/MqttRetainedMessagesManager.cs +++ b/Source/MQTTnet/Server/MqttRetainedMessagesManager.cs @@ -3,11 +3,14 @@ using System.Collections.Generic; using System.Linq; using System.Threading.Tasks; using MQTTnet.Diagnostics; +using MQTTnet.Internal; namespace MQTTnet.Server { public class MqttRetainedMessagesManager { + private readonly byte[] _emptyArray = new byte[0]; + private readonly AsyncLock _messagesLock = new AsyncLock(); private readonly Dictionary _messages = new Dictionary(); private readonly IMqttNetChildLogger _logger; @@ -31,7 +34,7 @@ namespace MQTTnet.Server { var retainedMessages = await _options.Storage.LoadRetainedMessagesAsync().ConfigureAwait(false); - lock (_messages) + using (await _messagesLock.WaitAsync().ConfigureAwait(false)) { _messages.Clear(); foreach (var retainedMessage in retainedMessages) @@ -52,8 +55,7 @@ namespace MQTTnet.Server try { - List messagesForSave = null; - lock (_messages) + using (await _messagesLock.WaitAsync().ConfigureAwait(false)) { var saveIsRequired = false; @@ -71,7 +73,7 @@ namespace MQTTnet.Server } else { - if (existingMessage.QualityOfServiceLevel != applicationMessage.QualityOfServiceLevel || !existingMessage.Payload.SequenceEqual(applicationMessage.Payload ?? new byte[0])) + if (existingMessage.QualityOfServiceLevel != applicationMessage.QualityOfServiceLevel || !existingMessage.Payload.SequenceEqual(applicationMessage.Payload ?? _emptyArray)) { _messages[applicationMessage.Topic] = applicationMessage; saveIsRequired = true; @@ -83,20 +85,13 @@ namespace MQTTnet.Server if (saveIsRequired) { - messagesForSave = new List(_messages.Values); + if (_options.Storage != null) + { + var messagesForSave = new List(_messages.Values); + await _options.Storage.SaveRetainedMessagesAsync(messagesForSave).ConfigureAwait(false); + } } } - - if (messagesForSave == null) - { - _logger.Verbose("Skipped saving retained messages because no changes were detected."); - return; - } - - if (_options.Storage != null) - { - await _options.Storage.SaveRetainedMessagesAsync(messagesForSave).ConfigureAwait(false); - } } catch (Exception exception) { @@ -104,11 +99,11 @@ namespace MQTTnet.Server } } - public IList GetSubscribedMessages(ICollection topicFilters) + public async Task> GetSubscribedMessagesAsync(ICollection topicFilters) { var retainedMessages = new List(); - lock (_messages) + using (await _messagesLock.WaitAsync().ConfigureAwait(false)) { foreach (var retainedMessage in _messages.Values) { @@ -128,27 +123,25 @@ namespace MQTTnet.Server return retainedMessages; } - public IList GetMessages() + public async Task> GetMessagesAsync() { - lock (_messages) + using (await _messagesLock.WaitAsync().ConfigureAwait(false)) { return _messages.Values.ToList(); } } - public Task ClearMessagesAsync() + public async Task ClearMessagesAsync() { - lock (_messages) + using (await _messagesLock.WaitAsync().ConfigureAwait(false)) { _messages.Clear(); - } - if (_options.Storage != null) - { - return _options.Storage.SaveRetainedMessagesAsync(new List()); + if (_options.Storage != null) + { + await _options.Storage.SaveRetainedMessagesAsync(new List()).ConfigureAwait(false); + } } - - return Task.FromResult((object)null); } } } diff --git a/Source/MQTTnet/Server/MqttServer.cs b/Source/MQTTnet/Server/MqttServer.cs index 15e23e7..aae37e6 100644 --- a/Source/MQTTnet/Server/MqttServer.cs +++ b/Source/MQTTnet/Server/MqttServer.cs @@ -48,17 +48,12 @@ namespace MQTTnet.Server public Task> GetClientSessionsStatusAsync() { - return Task.FromResult(_clientSessionsManager.GetClientStatus()); - } - - public IList GetClientSessionsStatus() - { - return _clientSessionsManager.GetClientStatus(); + return _clientSessionsManager.GetClientStatusAsync(); } public IList GetRetainedMessages() { - return _retainedMessagesManager.GetMessages(); + return _retainedMessagesManager.GetMessagesAsync().GetAwaiter().GetResult(); } public Task SubscribeAsync(string clientId, IList topicFilters) @@ -123,7 +118,7 @@ namespace MQTTnet.Server _cancellationTokenSource.Cancel(false); - _clientSessionsManager.Stop(); + _clientSessionsManager.StopAsync().ConfigureAwait(false); foreach (var adapter in _adapters) { diff --git a/Source/MQTTnet/Server/MqttServerApplicationMessageInterceptorDelegate.cs b/Source/MQTTnet/Server/MqttServerApplicationMessageInterceptorDelegate.cs new file mode 100644 index 0000000..82ba0f3 --- /dev/null +++ b/Source/MQTTnet/Server/MqttServerApplicationMessageInterceptorDelegate.cs @@ -0,0 +1,31 @@ +using System; +using System.Threading.Tasks; + +namespace MQTTnet.Server +{ + public class MqttServerApplicationMessageInterceptorDelegate : IMqttServerApplicationMessageInterceptor + { + private readonly Func _callback; + + public MqttServerApplicationMessageInterceptorDelegate(Action callback) + { + if (callback == null) throw new ArgumentNullException(nameof(callback)); + + _callback = context => + { + callback(context); + return Task.FromResult(0); + }; + } + + public MqttServerApplicationMessageInterceptorDelegate(Func callback) + { + _callback = callback ?? throw new ArgumentNullException(nameof(callback)); + } + + public Task InterceptApplicationMessagePublishAsync(MqttApplicationMessageInterceptorContext context) + { + return _callback(context); + } + } +} diff --git a/Source/MQTTnet/Server/MqttServerConnectionValidatorDelegate.cs b/Source/MQTTnet/Server/MqttServerConnectionValidatorDelegate.cs new file mode 100644 index 0000000..954d10d --- /dev/null +++ b/Source/MQTTnet/Server/MqttServerConnectionValidatorDelegate.cs @@ -0,0 +1,31 @@ +using System; +using System.Threading.Tasks; + +namespace MQTTnet.Server +{ + public class MqttServerConnectionValidatorDelegate : IMqttServerConnectionValidator + { + private readonly Func _callback; + + public MqttServerConnectionValidatorDelegate(Action callback) + { + if (callback == null) throw new ArgumentNullException(nameof(callback)); + + _callback = context => + { + callback(context); + return Task.FromResult(0); + }; + } + + public MqttServerConnectionValidatorDelegate(Func callback) + { + _callback = callback ?? throw new ArgumentNullException(nameof(callback)); + } + + public Task ValidateConnection(MqttConnectionValidatorContext context) + { + return _callback(context); + } + } +} diff --git a/Source/MQTTnet/Server/MqttServerOptions.cs b/Source/MQTTnet/Server/MqttServerOptions.cs index 02b2bce..f893dac 100644 --- a/Source/MQTTnet/Server/MqttServerOptions.cs +++ b/Source/MQTTnet/Server/MqttServerOptions.cs @@ -16,13 +16,13 @@ namespace MQTTnet.Server public TimeSpan DefaultCommunicationTimeout { get; set; } = TimeSpan.FromSeconds(15); - public Action ConnectionValidator { get; set; } + public IMqttServerConnectionValidator ConnectionValidator { get; set; } - public Action ApplicationMessageInterceptor { get; set; } + public IMqttServerApplicationMessageInterceptor ApplicationMessageInterceptor { get; set; } + + public IMqttServerClientMessageQueueInterceptor ClientMessageQueueInterceptor { get; set; } - public Action ClientMessageQueueInterceptor { get; set; } - - public Action SubscriptionInterceptor { get; set; } + public IMqttServerSubscriptionInterceptor SubscriptionInterceptor { get; set; } public IMqttServerStorage Storage { get; set; } } diff --git a/Source/MQTTnet/Server/MqttServerOptionsBuilder.cs b/Source/MQTTnet/Server/MqttServerOptionsBuilder.cs index 4540062..f43ff86 100644 --- a/Source/MQTTnet/Server/MqttServerOptionsBuilder.cs +++ b/Source/MQTTnet/Server/MqttServerOptionsBuilder.cs @@ -99,24 +99,42 @@ namespace MQTTnet.Server return this; } - public MqttServerOptionsBuilder WithConnectionValidator(Action value) + public MqttServerOptionsBuilder WithConnectionValidator(IMqttServerConnectionValidator value) { _options.ConnectionValidator = value; return this; } - public MqttServerOptionsBuilder WithApplicationMessageInterceptor(Action value) + public MqttServerOptionsBuilder WithConnectionValidator(Action value) + { + _options.ConnectionValidator = new MqttServerConnectionValidatorDelegate(value); + return this; + } + + public MqttServerOptionsBuilder WithApplicationMessageInterceptor(IMqttServerApplicationMessageInterceptor value) { _options.ApplicationMessageInterceptor = value; return this; } - public MqttServerOptionsBuilder WithSubscriptionInterceptor(Action value) + public MqttServerOptionsBuilder WithApplicationMessageInterceptor(Action value) + { + _options.ApplicationMessageInterceptor = new MqttServerApplicationMessageInterceptorDelegate(value); + return this; + } + + public MqttServerOptionsBuilder WithSubscriptionInterceptor(IMqttServerSubscriptionInterceptor value) { _options.SubscriptionInterceptor = value; return this; } + public MqttServerOptionsBuilder WithSubscriptionInterceptor(Action value) + { + _options.SubscriptionInterceptor = new MqttServerSubscriptionInterceptorDelegate(value); + return this; + } + public MqttServerOptionsBuilder WithPersistentSessions() { _options.EnablePersistentSessions = true; diff --git a/Source/MQTTnet/Server/MqttServerSubscriptionInterceptorDelegate.cs b/Source/MQTTnet/Server/MqttServerSubscriptionInterceptorDelegate.cs new file mode 100644 index 0000000..f500e07 --- /dev/null +++ b/Source/MQTTnet/Server/MqttServerSubscriptionInterceptorDelegate.cs @@ -0,0 +1,31 @@ +using System; +using System.Threading.Tasks; + +namespace MQTTnet.Server +{ + public class MqttServerSubscriptionInterceptorDelegate : IMqttServerSubscriptionInterceptor + { + private readonly Func _callback; + + public MqttServerSubscriptionInterceptorDelegate(Action callback) + { + if (callback == null) throw new ArgumentNullException(nameof(callback)); + + _callback = context => + { + callback(context); + return Task.FromResult(0); + }; + } + + public MqttServerSubscriptionInterceptorDelegate(Func callback) + { + _callback = callback ?? throw new ArgumentNullException(nameof(callback)); + } + + public Task InterceptSubscriptionAsync(MqttSubscriptionInterceptorContext context) + { + return _callback(context); + } + } +} diff --git a/Tests/MQTTnet.Core.Tests/AsyncLockTests.cs b/Tests/MQTTnet.Core.Tests/AsyncLockTests.cs index 81b0ad0..7073f70 100644 --- a/Tests/MQTTnet.Core.Tests/AsyncLockTests.cs +++ b/Tests/MQTTnet.Core.Tests/AsyncLockTests.cs @@ -21,7 +21,7 @@ namespace MQTTnet.Tests #pragma warning disable CS4014 // Because this call is not awaited, execution of the current method continues before the call is completed threads[i] = Task.Run(async () => { - using (var releaser = await @lock.LockAsync(CancellationToken.None)) + using (var releaser = await @lock.WaitAsync(CancellationToken.None)) { var localI = globalI; await Task.Delay(10); // Increase the chance for wrong data. diff --git a/Tests/MQTTnet.Core.Tests/MqttKeepAliveMonitorTests.cs b/Tests/MQTTnet.Core.Tests/MqttKeepAliveMonitorTests.cs index 91b46de..75b088a 100644 --- a/Tests/MQTTnet.Core.Tests/MqttKeepAliveMonitorTests.cs +++ b/Tests/MQTTnet.Core.Tests/MqttKeepAliveMonitorTests.cs @@ -82,9 +82,10 @@ namespace MQTTnet.Tests throw new NotSupportedException(); } - public void Stop(MqttClientDisconnectType disconnectType) + public Task StopAsync(MqttClientDisconnectType disconnectType) { StopCalledCount++; + return Task.FromResult(0); } public Task SubscribeAsync(IList topicFilters) diff --git a/Tests/MQTTnet.Core.Tests/MqttServerTests.cs b/Tests/MQTTnet.Core.Tests/MqttServerTests.cs index 83f82cd..b3f9f21 100644 --- a/Tests/MQTTnet.Core.Tests/MqttServerTests.cs +++ b/Tests/MQTTnet.Core.Tests/MqttServerTests.cs @@ -1,6 +1,7 @@ using System; using System.Collections.Generic; using System.Linq; +using System.Net.Sockets; using System.Text; using System.Threading; using System.Threading.Tasks; @@ -52,7 +53,7 @@ namespace MQTTnet.Tests } [TestMethod] - public async Task MqttServer_WillMessage() + public async Task MqttServer_Will_Message() { var serverAdapter = new TestMqttServerAdapter(); var s = new MqttFactory().CreateMqttServer(new[] { serverAdapter }, new MqttNetLogger()); @@ -84,7 +85,7 @@ namespace MQTTnet.Tests } [TestMethod] - public async Task MqttServer_SubscribeUnsubscribe() + public async Task MqttServer_Subscribe_Unsubscribe() { var serverAdapter = new TestMqttServerAdapter(); var s = new MqttFactory().CreateMqttServer(new[] { serverAdapter }, new MqttNetLogger()); @@ -230,7 +231,7 @@ namespace MQTTnet.Tests } [TestMethod] - public async Task MqttServer_SessionTakeover() + public async Task MqttServer_Session_Takeover() { var server = new MqttFactory().CreateMqttServer(); try @@ -299,38 +300,43 @@ namespace MQTTnet.Tests } [TestMethod] - public async Task MqttServer_HandleCleanDisconnect() + public async Task MqttServer_Handle_Clean_Disconnect() { - var serverAdapter = new MqttTcpServerAdapter(new MqttNetLogger().CreateChildLogger()); - var s = new MqttFactory().CreateMqttServer(new[] { serverAdapter }, new MqttNetLogger()); - - var clientConnectedCalled = 0; - var clientDisconnectedCalled = 0; + var s = new MqttFactory().CreateMqttServer(); + try + { + var clientConnectedCalled = 0; + var clientDisconnectedCalled = 0; - s.ClientConnected += (_, __) => clientConnectedCalled++; - s.ClientDisconnected += (_, __) => clientDisconnectedCalled++; + s.ClientConnected += (_, __) => clientConnectedCalled++; + s.ClientDisconnected += (_, __) => clientDisconnectedCalled++; - var clientOptions = new MqttClientOptionsBuilder() - .WithTcpServer("localhost") - .Build(); + var clientOptions = new MqttClientOptionsBuilder() + .WithTcpServer("localhost") + .Build(); - await s.StartAsync(new MqttServerOptions()); + await s.StartAsync(new MqttServerOptions()); - var c1 = new MqttFactory().CreateMqttClient(); + var c1 = new MqttFactory().CreateMqttClient(); - await c1.ConnectAsync(clientOptions); + await c1.ConnectAsync(clientOptions); - await Task.Delay(100); + await Task.Delay(100); - await c1.DisconnectAsync(); + await c1.DisconnectAsync(); - await Task.Delay(100); + await Task.Delay(100); - await s.StopAsync(); + await s.StopAsync(); - await Task.Delay(100); + await Task.Delay(100); - Assert.AreEqual(clientConnectedCalled, clientDisconnectedCalled); + Assert.AreEqual(clientConnectedCalled, clientDisconnectedCalled); + } + finally + { + await s.StopAsync(); + } } [TestMethod] @@ -410,11 +416,13 @@ namespace MQTTnet.Tests .WithPayload("value" + j).WithRetainFlag().Build()).GetAwaiter().GetResult(); } + Thread.Sleep(100); + client.DisconnectAsync().GetAwaiter().GetResult(); } }); - await Task.Delay(100); + await Task.Delay(1000); var retainedMessages = server.GetRetainedMessages(); @@ -432,7 +440,7 @@ namespace MQTTnet.Tests } [TestMethod] - public async Task MqttServer_RetainedMessagesFlow() + public async Task MqttServer_Retained_Messages_Flow() { var retainedMessage = new MqttApplicationMessageBuilder().WithTopic("r").WithPayload("r").WithRetainFlag().Build(); var serverAdapter = new TestMqttServerAdapter(); @@ -440,9 +448,9 @@ namespace MQTTnet.Tests await s.StartAsync(new MqttServerOptions()); var c1 = await serverAdapter.ConnectTestClient("c1"); await c1.PublishAsync(retainedMessage); - Thread.Sleep(500); + await Task.Delay(500); await c1.DisconnectAsync(); - Thread.Sleep(500); + await Task.Delay(500); var receivedMessages = 0; var c2 = await serverAdapter.ConnectTestClient("c2"); @@ -468,7 +476,7 @@ namespace MQTTnet.Tests } [TestMethod] - public async Task MqttServer_NoRetainedMessage() + public async Task MqttServer_No_Retained_Message() { var serverAdapter = new TestMqttServerAdapter(); var s = new MqttFactory().CreateMqttServer(new[] { serverAdapter }, new MqttNetLogger()); @@ -498,7 +506,7 @@ namespace MQTTnet.Tests } [TestMethod] - public async Task MqttServer_RetainedMessage() + public async Task MqttServer_Retained_Message() { var serverAdapter = new TestMqttServerAdapter(); var s = new MqttFactory().CreateMqttServer(new[] { serverAdapter }, new MqttNetLogger()); @@ -535,7 +543,7 @@ namespace MQTTnet.Tests } [TestMethod] - public async Task MqttServer_ClearRetainedMessage() + public async Task MqttServer_Clear_Retained_Message() { var serverAdapter = new TestMqttServerAdapter(); var s = new MqttFactory().CreateMqttServer(new[] { serverAdapter }, new MqttNetLogger()); @@ -567,7 +575,7 @@ namespace MQTTnet.Tests } [TestMethod] - public async Task MqttServer_PersistRetainedMessage() + public async Task MqttServer_Persist_Retained_Message() { var storage = new TestStorage(); var serverAdapter = new TestMqttServerAdapter(); @@ -629,7 +637,7 @@ namespace MQTTnet.Tests try { - var options = new MqttServerOptions { ApplicationMessageInterceptor = Interceptor }; + var options = new MqttServerOptions { ApplicationMessageInterceptor = new MqttServerApplicationMessageInterceptorDelegate(c => Interceptor(c)) }; await s.StartAsync(options); @@ -692,7 +700,7 @@ namespace MQTTnet.Tests } [TestMethod] - public async Task MqttServer_ConnectionDenied() + public async Task MqttServer_Connection_Denied() { var server = new MqttFactory().CreateMqttServer(); var client = new MqttFactory().CreateMqttClient(); @@ -791,7 +799,6 @@ namespace MQTTnet.Tests Assert.AreEqual("cdcd", flow); } - [TestMethod] public async Task MqttServer_StopAndRestart() { @@ -820,6 +827,66 @@ namespace MQTTnet.Tests await server.StopAsync(); } + [TestMethod] + public async Task MqttServer_Close_Idle_Connection() + { + var server = new MqttFactory().CreateMqttServer(); + + try + { + await server.StartAsync(new MqttServerOptionsBuilder().WithDefaultCommunicationTimeout(TimeSpan.FromSeconds(4)).Build()); + + var client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); + await client.ConnectAsync("localhost", 1883); + + // Don't send anything. The server should close the connection. + await Task.Delay(TimeSpan.FromSeconds(5)); + + try + { + await client.ReceiveAsync(new ArraySegment(new byte[10]), SocketFlags.Partial); + Assert.Fail("Receive should throw an exception."); + } + catch (SocketException) + { + } + } + finally + { + await server.StopAsync(); + } + } + + [TestMethod] + public async Task MqttServer_Send_Garbage() + { + var server = new MqttFactory().CreateMqttServer(); + + try + { + await server.StartAsync(new MqttServerOptionsBuilder().WithDefaultCommunicationTimeout(TimeSpan.FromSeconds(4)).Build()); + + var client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); + await client.ConnectAsync("localhost", 1883); + await client.SendAsync(Encoding.UTF8.GetBytes("Garbage"), SocketFlags.None); + + await Task.Delay(TimeSpan.FromSeconds(5)); + + try + { + await client.ReceiveAsync(new ArraySegment(new byte[10]), SocketFlags.Partial); + Assert.Fail("Receive should throw an exception."); + } + catch (SocketException) + { + } + } + finally + { + await server.StopAsync(); + } + } + private class TestStorage : IMqttServerStorage { public IList Messages = new List(); diff --git a/Tests/MQTTnet.Core.Tests/MqttSubscriptionsManagerTests.cs b/Tests/MQTTnet.Core.Tests/MqttSubscriptionsManagerTests.cs index 4af3b10..7b7e959 100644 --- a/Tests/MQTTnet.Core.Tests/MqttSubscriptionsManagerTests.cs +++ b/Tests/MQTTnet.Core.Tests/MqttSubscriptionsManagerTests.cs @@ -16,7 +16,7 @@ namespace MQTTnet.Tests var sp = new MqttSubscribePacket(); sp.TopicFilters.Add(new TopicFilterBuilder().WithTopic("A/B/C").Build()); - sm.Subscribe(sp); + sm.SubscribeAsync(sp).GetAwaiter().GetResult(); var result = sm.CheckSubscriptions("A/B/C", MqttQualityOfServiceLevel.AtMostOnce); Assert.IsTrue(result.IsSubscribed); @@ -31,7 +31,7 @@ namespace MQTTnet.Tests var sp = new MqttSubscribePacket(); sp.TopicFilters.Add(new TopicFilter { Topic = "A/B/C", QualityOfServiceLevel = MqttQualityOfServiceLevel.AtMostOnce }); - sm.Subscribe(sp); + sm.SubscribeAsync(sp).GetAwaiter().GetResult(); var result = sm.CheckSubscriptions("A/B/C", MqttQualityOfServiceLevel.ExactlyOnce); Assert.IsTrue(result.IsSubscribed); @@ -47,7 +47,7 @@ namespace MQTTnet.Tests sp.TopicFilters.Add(new TopicFilter { Topic = "#", QualityOfServiceLevel = MqttQualityOfServiceLevel.AtMostOnce }); sp.TopicFilters.Add(new TopicFilter { Topic = "A/B/C", QualityOfServiceLevel = MqttQualityOfServiceLevel.AtLeastOnce }); - sm.Subscribe(sp); + sm.SubscribeAsync(sp).GetAwaiter().GetResult(); var result = sm.CheckSubscriptions("A/B/C", MqttQualityOfServiceLevel.ExactlyOnce); Assert.IsTrue(result.IsSubscribed); @@ -62,7 +62,7 @@ namespace MQTTnet.Tests var sp = new MqttSubscribePacket(); sp.TopicFilters.Add(new TopicFilterBuilder().WithTopic("A/B/C").Build()); - sm.Subscribe(sp); + sm.SubscribeAsync(sp).GetAwaiter().GetResult(); Assert.IsFalse(sm.CheckSubscriptions("A/B/X", MqttQualityOfServiceLevel.AtMostOnce).IsSubscribed); } @@ -75,7 +75,7 @@ namespace MQTTnet.Tests var sp = new MqttSubscribePacket(); sp.TopicFilters.Add(new TopicFilterBuilder().WithTopic("A/B/C").Build()); - sm.Subscribe(sp); + sm.SubscribeAsync(sp).GetAwaiter().GetResult(); Assert.IsTrue(sm.CheckSubscriptions("A/B/C", MqttQualityOfServiceLevel.AtMostOnce).IsSubscribed); diff --git a/Tests/MQTTnet.TestApp.NetCore/ServerTest.cs b/Tests/MQTTnet.TestApp.NetCore/ServerTest.cs index 12fa1aa..b00e3ac 100644 --- a/Tests/MQTTnet.TestApp.NetCore/ServerTest.cs +++ b/Tests/MQTTnet.TestApp.NetCore/ServerTest.cs @@ -23,7 +23,7 @@ namespace MQTTnet.TestApp.NetCore { var options = new MqttServerOptions { - ConnectionValidator = p => + ConnectionValidator = new MqttServerConnectionValidatorDelegate(p => { if (p.ClientId == "SpecialClient") { @@ -32,11 +32,11 @@ namespace MQTTnet.TestApp.NetCore p.ReturnCode = MqttConnectReturnCode.ConnectionRefusedBadUsernameOrPassword; } } - }, + }), Storage = new RetainedMessageHandler(), - ApplicationMessageInterceptor = context => + ApplicationMessageInterceptor = new MqttServerApplicationMessageInterceptorDelegate(context => { if (MqttTopicFilterComparer.IsMatch(context.ApplicationMessage.Topic, "/myTopic/WithTimestamp/#")) { @@ -50,8 +50,9 @@ namespace MQTTnet.TestApp.NetCore context.AcceptPublish = false; context.CloseConnection = true; } - }, - SubscriptionInterceptor = context => + }), + + SubscriptionInterceptor = new MqttServerSubscriptionInterceptorDelegate(context => { if (context.TopicFilter.Topic.StartsWith("admin/foo/bar") && context.ClientId != "theAdmin") { @@ -63,7 +64,7 @@ namespace MQTTnet.TestApp.NetCore context.AcceptSubscription = false; context.CloseConnection = true; } - } + }) }; // Extend the timestamp for all messages from clients. diff --git a/Tests/MQTTnet.TestApp.UniversalWindows/MainPage.xaml.cs b/Tests/MQTTnet.TestApp.UniversalWindows/MainPage.xaml.cs index 44bf102..2065a97 100644 --- a/Tests/MQTTnet.TestApp.UniversalWindows/MainPage.xaml.cs +++ b/Tests/MQTTnet.TestApp.UniversalWindows/MainPage.xaml.cs @@ -450,7 +450,7 @@ namespace MQTTnet.TestApp.UniversalWindows return; } - var sessions = _mqttServer.GetClientSessionsStatus(); + var sessions = _mqttServer.GetClientSessionsStatusAsync().GetAwaiter().GetResult(); _sessions.Clear(); foreach (var session in sessions) @@ -568,7 +568,7 @@ namespace MQTTnet.TestApp.UniversalWindows { var options = new MqttServerOptions(); - options.ConnectionValidator = c => + options.ConnectionValidator = new MqttServerConnectionValidatorDelegate(c => { if (c.ClientId.Length < 10) { @@ -589,7 +589,7 @@ namespace MQTTnet.TestApp.UniversalWindows } c.ReturnCode = MqttConnectReturnCode.ConnectionAccepted; - }; + }); var factory = new MqttFactory(); var mqttServer = factory.CreateMqttServer(); @@ -633,7 +633,7 @@ namespace MQTTnet.TestApp.UniversalWindows { }; - options.ConnectionValidator = c => + options.ConnectionValidator = new MqttServerConnectionValidatorDelegate(c => { if (c.ClientId != "Highlander") { @@ -642,7 +642,7 @@ namespace MQTTnet.TestApp.UniversalWindows } c.ReturnCode = MqttConnectReturnCode.ConnectionAccepted; - }; + }); var mqttServer = new MqttFactory().CreateMqttServer(); await mqttServer.StartAsync(optionsBuilder.Build()); @@ -652,7 +652,7 @@ namespace MQTTnet.TestApp.UniversalWindows // Setup client validator. var options = new MqttServerOptions { - ConnectionValidator = c => + ConnectionValidator = new MqttServerConnectionValidatorDelegate(c => { if (c.ClientId.Length < 10) { @@ -673,7 +673,7 @@ namespace MQTTnet.TestApp.UniversalWindows } c.ReturnCode = MqttConnectReturnCode.ConnectionAccepted; - } + }) }; }