diff --git a/Source/MQTTnet/Server/MqttClientSessionsManager.cs b/Source/MQTTnet/Server/MqttClientSessionsManager.cs index db70e95..1e7c1b6 100644 --- a/Source/MQTTnet/Server/MqttClientSessionsManager.cs +++ b/Source/MQTTnet/Server/MqttClientSessionsManager.cs @@ -1,6 +1,7 @@ using System; using System.Collections.Concurrent; using System.Collections.Generic; +using System.IO; using System.Threading; using System.Threading.Tasks; using MQTTnet.Adapter; @@ -29,6 +30,27 @@ namespace MQTTnet.Server private readonly IMqttServerOptions _options; private readonly IMqttNetChildLogger _logger; + public static class TestLogger + { + public static void WriteLine(string message) + { + var path = @"c:\temp\test1.txt"; + FileStream logFile; + if (!System.IO.File.Exists(path)) + logFile = System.IO.File.Create(path); + else + logFile = System.IO.File.Open(path, FileMode.Append); + + using (var writer = new System.IO.StreamWriter(logFile)) + { + writer.WriteLine($"{DateTime.Now} - {message}"); + } + + logFile.Dispose(); + } + } + + public MqttClientSessionsManager( IMqttServerOptions options, MqttRetainedMessagesManager retainedMessagesManager, @@ -36,6 +58,7 @@ namespace MQTTnet.Server MqttServerEventDispatcher eventDispatcher, IMqttNetChildLogger logger) { + TestLogger.WriteLine("Newly new"); _cancellationToken = cancellationToken; if (logger == null) throw new ArgumentNullException(nameof(logger)); @@ -48,11 +71,13 @@ namespace MQTTnet.Server public void Start() { + TestLogger.WriteLine("Start"); Task.Run(() => TryProcessQueuedApplicationMessagesAsync(_cancellationToken), _cancellationToken).Forget(_logger); } public async Task StopAsync() { + TestLogger.WriteLine("Stop"); foreach (var connection in _connections.Values) { await connection.StopAsync().ConfigureAwait(false); @@ -66,6 +91,7 @@ namespace MQTTnet.Server public Task> GetClientStatusAsync() { + TestLogger.WriteLine("Status"); var result = new List(); foreach (var connection in _connections.Values) @@ -85,6 +111,7 @@ namespace MQTTnet.Server public Task> GetSessionStatusAsync() { + TestLogger.WriteLine("Session"); var result = new List(); foreach (var session in _sessions.Values) @@ -100,6 +127,7 @@ namespace MQTTnet.Server public void DispatchApplicationMessage(MqttApplicationMessage applicationMessage, MqttClientConnection sender) { + TestLogger.WriteLine("Message"); if (applicationMessage == null) throw new ArgumentNullException(nameof(applicationMessage)); _messageQueue.Enqueue(new MqttEnqueuedApplicationMessage(applicationMessage, sender)); @@ -107,6 +135,7 @@ namespace MQTTnet.Server public Task SubscribeAsync(string clientId, ICollection topicFilters) { + TestLogger.WriteLine("sub"); if (clientId == null) throw new ArgumentNullException(nameof(clientId)); if (topicFilters == null) throw new ArgumentNullException(nameof(topicFilters)); @@ -120,6 +149,7 @@ namespace MQTTnet.Server public Task UnsubscribeAsync(string clientId, IEnumerable topicFilters) { + TestLogger.WriteLine("unsub"); if (clientId == null) throw new ArgumentNullException(nameof(clientId)); if (topicFilters == null) throw new ArgumentNullException(nameof(topicFilters)); @@ -133,6 +163,7 @@ namespace MQTTnet.Server public async Task DeleteSessionAsync(string clientId) { + TestLogger.WriteLine("Delete"); if (_connections.TryGetValue(clientId, out var connection)) { await connection.StopAsync().ConfigureAwait(false); @@ -147,11 +178,13 @@ namespace MQTTnet.Server public void Dispose() { + TestLogger.WriteLine("byebye"); _messageQueue?.Dispose(); } private async Task TryProcessQueuedApplicationMessagesAsync(CancellationToken cancellationToken) { + TestLogger.WriteLine("queue"); while (!cancellationToken.IsCancellationRequested) { try @@ -170,6 +203,7 @@ namespace MQTTnet.Server private async Task TryProcessNextQueuedApplicationMessageAsync(CancellationToken cancellationToken) { + TestLogger.WriteLine("process message"); try { if (cancellationToken.IsCancellationRequested) @@ -178,6 +212,7 @@ namespace MQTTnet.Server } var dequeueResult = await _messageQueue.TryDequeueAsync(cancellationToken).ConfigureAwait(false); + TestLogger.WriteLine("dequeued"); var queuedApplicationMessage = dequeueResult.Item; var sender = queuedApplicationMessage.Sender; @@ -209,6 +244,7 @@ namespace MQTTnet.Server await _retainedMessagesManager.HandleMessageAsync(sender?.ClientId, applicationMessage).ConfigureAwait(false); } + TestLogger.WriteLine($"sessions: {_sessions.Count}"); foreach (var clientSession in _sessions.Values) { clientSession.EnqueueApplicationMessage( @@ -219,18 +255,23 @@ namespace MQTTnet.Server } catch (OperationCanceledException) { + TestLogger.WriteLine($"no queue"); } catch (Exception exception) { + TestLogger.WriteLine($"no queue {exception}"); _logger.Error(exception, "Unhandled exception while processing next queued application message."); } } private async Task HandleClientAsync(IMqttChannelAdapter channelAdapter, CancellationToken cancellationToken) { + TestLogger.WriteLine($"handle"); var disconnectType = MqttClientDisconnectType.NotClean; string clientId = null; + var ok = true; + try { var firstPacket = await channelAdapter.ReceivePacketAsync(_options.DefaultCommunicationTimeout, cancellationToken).ConfigureAwait(false); @@ -241,11 +282,14 @@ namespace MQTTnet.Server } clientId = connectPacket.ClientId; + TestLogger.WriteLine($"validating {clientId}"); var connectionValidatorContext = await ValidateConnectionAsync(connectPacket, channelAdapter).ConfigureAwait(false); if (connectionValidatorContext.ReasonCode != MqttConnectReasonCode.Success) { + TestLogger.WriteLine($"{clientId} not good"); + ok = false; // Send failure response here without preparing a session. The result for a successful connect // will be sent from the session itself. var connAckPacket = channelAdapter.PacketFormatterAdapter.DataConverter.CreateConnAckPacket(connectionValidatorContext); @@ -254,42 +298,53 @@ namespace MQTTnet.Server return; } + TestLogger.WriteLine($"{clientId} good"); + var connection = await CreateConnectionAsync(connectPacket, connectionValidatorContext, channelAdapter).ConfigureAwait(false); await _eventDispatcher.HandleClientConnectedAsync(clientId).ConfigureAwait(false); disconnectType = await connection.RunAsync().ConfigureAwait(false); + + TestLogger.WriteLine($"{clientId} all good"); } catch (OperationCanceledException) { + TestLogger.WriteLine($"no"); } catch (Exception exception) { + TestLogger.WriteLine($"no {exception}"); _logger.Error(exception, exception.Message); } finally { - if (clientId != null) - { - _connections.TryRemove(clientId, out _); - - if (!_options.EnablePersistentSessions) + if (ok) + { + TestLogger.WriteLine($"finally {clientId}"); + if (clientId != null) { - await DeleteSessionAsync(clientId).ConfigureAwait(false); + _connections.TryRemove(clientId, out _); + + if (!_options.EnablePersistentSessions) + { + await DeleteSessionAsync(clientId).ConfigureAwait(false); + } } - } - await TryCleanupChannelAsync(channelAdapter).ConfigureAwait(false); + await TryCleanupChannelAsync(channelAdapter).ConfigureAwait(false); - if (clientId != null) - { - await _eventDispatcher.TryHandleClientDisconnectedAsync(clientId, disconnectType).ConfigureAwait(false); + if (clientId != null) + { + await _eventDispatcher.TryHandleClientDisconnectedAsync(clientId, disconnectType).ConfigureAwait(false); + } } } } private async Task ValidateConnectionAsync(MqttConnectPacket connectPacket, IMqttChannelAdapter channelAdapter) { + TestLogger.WriteLine("validate"); var context = new MqttConnectionValidatorContext(connectPacket, channelAdapter, new ConcurrentDictionary()); var connectionValidator = _options.ConnectionValidator; @@ -318,6 +373,7 @@ namespace MQTTnet.Server private async Task CreateConnectionAsync(MqttConnectPacket connectPacket, MqttConnectionValidatorContext connectionValidatorContext, IMqttChannelAdapter channelAdapter) { + TestLogger.WriteLine("create"); await _createConnectionGate.WaitAsync(_cancellationToken).ConfigureAwait(false); try { @@ -364,6 +420,7 @@ namespace MQTTnet.Server private async Task InterceptApplicationMessageAsync(MqttClientConnection senderConnection, MqttApplicationMessage applicationMessage) { + TestLogger.WriteLine("intercept"); var interceptor = _options.ApplicationMessageInterceptor; if (interceptor == null) { @@ -392,6 +449,7 @@ namespace MQTTnet.Server private async Task TryCleanupChannelAsync(IMqttChannelAdapter channelAdapter) { + TestLogger.WriteLine("clean"); try { await channelAdapter.DisconnectAsync(_options.DefaultCommunicationTimeout, CancellationToken.None).ConfigureAwait(false); diff --git a/Source/MQTTnet/Server/MqttClientSubscriptionsManager.cs b/Source/MQTTnet/Server/MqttClientSubscriptionsManager.cs index c84a018..3b166ba 100644 --- a/Source/MQTTnet/Server/MqttClientSubscriptionsManager.cs +++ b/Source/MQTTnet/Server/MqttClientSubscriptionsManager.cs @@ -4,6 +4,7 @@ using System.Linq; using System.Threading.Tasks; using MQTTnet.Packets; using MQTTnet.Protocol; +using static MQTTnet.Server.MqttClientSessionsManager; namespace MQTTnet.Server { @@ -16,6 +17,7 @@ namespace MQTTnet.Server public MqttClientSubscriptionsManager(MqttClientSession clientSession, MqttServerEventDispatcher eventDispatcher, IMqttServerOptions serverOptions) { + TestLogger.WriteLine("sub manager"); _clientSession = clientSession ?? throw new ArgumentNullException(nameof(clientSession)); // TODO: Consider removing the server options here and build a new class "ISubscriptionInterceptor" and just pass it. The instance is generated in the root server class upon start. @@ -25,6 +27,7 @@ namespace MQTTnet.Server public async Task SubscribeAsync(MqttSubscribePacket subscribePacket, MqttConnectPacket connectPacket) { + TestLogger.WriteLine("sub1"); if (subscribePacket == null) throw new ArgumentNullException(nameof(subscribePacket)); if (connectPacket == null) throw new ArgumentNullException(nameof(connectPacket)); @@ -76,6 +79,7 @@ namespace MQTTnet.Server public async Task SubscribeAsync(IEnumerable topicFilters) { + TestLogger.WriteLine("sub2"); if (topicFilters == null) throw new ArgumentNullException(nameof(topicFilters)); foreach (var topicFilter in topicFilters) @@ -100,6 +104,7 @@ namespace MQTTnet.Server public async Task UnsubscribeAsync(MqttUnsubscribePacket unsubscribePacket) { + TestLogger.WriteLine("unsub1"); if (unsubscribePacket == null) throw new ArgumentNullException(nameof(unsubscribePacket)); var unsubAckPacket = new MqttUnsubAckPacket @@ -132,6 +137,7 @@ namespace MQTTnet.Server public Task UnsubscribeAsync(IEnumerable topicFilters) { + TestLogger.WriteLine("unsub2"); if (topicFilters == null) throw new ArgumentNullException(nameof(topicFilters)); lock (_subscriptions) @@ -147,6 +153,7 @@ namespace MQTTnet.Server public CheckSubscriptionsResult CheckSubscriptions(string topic, MqttQualityOfServiceLevel qosLevel) { + TestLogger.WriteLine("check"); var qosLevels = new HashSet(); lock (_subscriptions) diff --git a/Source/MQTTnet/Server/MqttServerEventDispatcher.cs b/Source/MQTTnet/Server/MqttServerEventDispatcher.cs index e6e608a..d230eac 100644 --- a/Source/MQTTnet/Server/MqttServerEventDispatcher.cs +++ b/Source/MQTTnet/Server/MqttServerEventDispatcher.cs @@ -2,6 +2,7 @@ using System.Threading.Tasks; using MQTTnet.Client.Receiving; using MQTTnet.Diagnostics; +using static MQTTnet.Server.MqttClientSessionsManager; namespace MQTTnet.Server { @@ -55,6 +56,7 @@ namespace MQTTnet.Server public Task HandleClientSubscribedTopicAsync(string clientId, TopicFilter topicFilter) { + TestLogger.WriteLine("handle sub"); var handler = ClientSubscribedTopicHandler; if (handler == null) { diff --git a/Tests/MQTTnet.Core.Tests/Server_Status_Tests.cs b/Tests/MQTTnet.Core.Tests/Server_Status_Tests.cs index 8111bd0..c4a5d27 100644 --- a/Tests/MQTTnet.Core.Tests/Server_Status_Tests.cs +++ b/Tests/MQTTnet.Core.Tests/Server_Status_Tests.cs @@ -6,6 +6,7 @@ using MQTTnet.Tests.Mockups; using MQTTnet.Client; using MQTTnet.Protocol; using MQTTnet.Server; +using System.Threading; namespace MQTTnet.Tests { @@ -55,10 +56,10 @@ namespace MQTTnet.Tests var c1 = await testEnvironment.ConnectClientAsync(new MqttClientOptionsBuilder().WithClientId("client1")); - await Task.Delay(500); + await Task.Delay(1000); var clientStatus = await server.GetClientStatusAsync(); - + Assert.AreEqual(1, clientStatus.Count); Assert.IsTrue(clientStatus.Any(s => s.ClientId == "client1")); diff --git a/Tests/MQTTnet.Core.Tests/Server_Tests.cs b/Tests/MQTTnet.Core.Tests/Server_Tests.cs index b2b3b70..5353ebd 100644 --- a/Tests/MQTTnet.Core.Tests/Server_Tests.cs +++ b/Tests/MQTTnet.Core.Tests/Server_Tests.cs @@ -917,6 +917,109 @@ namespace MQTTnet.Tests } } + + private Dictionary _connected; + private void ConnectionValidationHandler(MqttConnectionValidatorContext eventArgs) + { + if (_connected.ContainsKey(eventArgs.ClientId)) + { + eventArgs.ReasonCode = MqttConnectReasonCode.BadUserNameOrPassword; + return; + } + _connected[eventArgs.ClientId] = true; + eventArgs.ReasonCode = MqttConnectReasonCode.Success; + return; + } + + [TestMethod] + public async Task Same_Client_Id_Refuse_Connection() + { + using (var testEnvironment = new TestEnvironment()) + { + _connected = new Dictionary(); + var options = new MqttServerOptionsBuilder(); + options.WithConnectionValidator(e => ConnectionValidationHandler(e)); + var server = await testEnvironment.StartServerAsync(options); + + var events = new List(); + + server.ClientConnectedHandler = new MqttServerClientConnectedHandlerDelegate(_ => + { + lock (events) + { + events.Add("c"); + } + }); + + server.ClientDisconnectedHandler = new MqttServerClientDisconnectedHandlerDelegate(_ => + { + lock (events) + { + events.Add("d"); + } + }); + + var clientOptions = new MqttClientOptionsBuilder() + .WithClientId("same_id"); + + // c + var c1 = await testEnvironment.ConnectClientAsync(clientOptions); + + c1.UseDisconnectedHandler(_ => + { + lock (events) + { + events.Add("x"); + } + }); + + + c1.UseApplicationMessageReceivedHandler(_ => + { + lock (events) + { + events.Add("r"); + } + + }); + + c1.SubscribeAsync("topic").Wait(); + + await Task.Delay(500); + + c1.PublishAsync("topic").Wait(); + + await Task.Delay(500); + + + var flow = string.Join(string.Empty, events); + Assert.AreEqual("cr", flow); + + try + { + await testEnvironment.ConnectClientAsync(clientOptions); + Assert.Fail("same id connection is expected to fail"); + } + catch + { + //same id connection is expected to fail + } + + await Task.Delay(500); + + flow = string.Join(string.Empty, events); + Assert.AreEqual("cr", flow); + + c1.PublishAsync("topic").Wait(); + + await Task.Delay(500); + + flow = string.Join(string.Empty, events); + Assert.AreEqual("crr", flow); + + } + } + [TestMethod] public async Task Same_Client_Id_Connect_Disconnect_Event_Order() { @@ -956,17 +1059,40 @@ namespace MQTTnet.Tests // dc var c2 = await testEnvironment.ConnectClientAsync(clientOptions); + c2.UseApplicationMessageReceivedHandler(_ => + { + lock (events) + { + events.Add("r"); + } + + }); + c2.SubscribeAsync("topic").Wait(); + await Task.Delay(500); flow = string.Join(string.Empty, events); Assert.AreEqual("cdc", flow); + // r + c2.PublishAsync("topic").Wait(); + + await Task.Delay(500); + + flow = string.Join(string.Empty, events); + Assert.AreEqual("cdcr", flow); + + // nothing + + Assert.AreEqual(false, c1.IsConnected); await c1.DisconnectAsync(); + Assert.AreEqual (false, c1.IsConnected); await Task.Delay(500); // d + Assert.AreEqual(true, c2.IsConnected); await c2.DisconnectAsync(); await Task.Delay(500); @@ -974,7 +1100,7 @@ namespace MQTTnet.Tests await server.StopAsync(); flow = string.Join(string.Empty, events); - Assert.AreEqual("cdcd", flow); + Assert.AreEqual("cdcrd", flow); } }