Browse Source

Merge pull request #786 from llrosa/debugFixClientId

Fix for #785 - Skip client disconnection process if unauthorized
release/3.x.x
Christian 5 years ago
committed by GitHub
parent
commit
0950dfbb8b
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 146 additions and 14 deletions
  1. +16
    -11
      Source/MQTTnet/Server/MqttClientSessionsManager.cs
  2. +3
    -2
      Tests/MQTTnet.Core.Tests/Server_Status_Tests.cs
  3. +127
    -1
      Tests/MQTTnet.Core.Tests/Server_Tests.cs

+ 16
- 11
Source/MQTTnet/Server/MqttClientSessionsManager.cs View File

@@ -230,6 +230,7 @@ namespace MQTTnet.Server
{ {
var disconnectType = MqttClientDisconnectType.NotClean; var disconnectType = MqttClientDisconnectType.NotClean;
string clientId = null; string clientId = null;
var clientWasConnected = true;


try try
{ {
@@ -246,6 +247,7 @@ namespace MQTTnet.Server


if (connectionValidatorContext.ReasonCode != MqttConnectReasonCode.Success) if (connectionValidatorContext.ReasonCode != MqttConnectReasonCode.Success)
{ {
clientWasConnected = false;
// Send failure response here without preparing a session. The result for a successful connect // Send failure response here without preparing a session. The result for a successful connect
// will be sent from the session itself. // will be sent from the session itself.
var connAckPacket = channelAdapter.PacketFormatterAdapter.DataConverter.CreateConnAckPacket(connectionValidatorContext); var connAckPacket = channelAdapter.PacketFormatterAdapter.DataConverter.CreateConnAckPacket(connectionValidatorContext);
@@ -269,21 +271,24 @@ namespace MQTTnet.Server
} }
finally finally
{ {
if (clientId != null)
{
_connections.TryRemove(clientId, out _);

if (!_options.EnablePersistentSessions)
if (clientWasConnected)
{
if (clientId != null)
{ {
await DeleteSessionAsync(clientId).ConfigureAwait(false);
_connections.TryRemove(clientId, out _);

if (!_options.EnablePersistentSessions)
{
await DeleteSessionAsync(clientId).ConfigureAwait(false);
}
} }
}


await TryCleanupChannelAsync(channelAdapter).ConfigureAwait(false);
await TryCleanupChannelAsync(channelAdapter).ConfigureAwait(false);


if (clientId != null)
{
await _eventDispatcher.TryHandleClientDisconnectedAsync(clientId, disconnectType).ConfigureAwait(false);
if (clientId != null)
{
await _eventDispatcher.TryHandleClientDisconnectedAsync(clientId, disconnectType).ConfigureAwait(false);
}
} }
} }
} }


+ 3
- 2
Tests/MQTTnet.Core.Tests/Server_Status_Tests.cs View File

@@ -6,6 +6,7 @@ using MQTTnet.Tests.Mockups;
using MQTTnet.Client; using MQTTnet.Client;
using MQTTnet.Protocol; using MQTTnet.Protocol;
using MQTTnet.Server; using MQTTnet.Server;
using System.Threading;


namespace MQTTnet.Tests namespace MQTTnet.Tests
{ {
@@ -55,10 +56,10 @@ namespace MQTTnet.Tests


var c1 = await testEnvironment.ConnectClientAsync(new MqttClientOptionsBuilder().WithClientId("client1")); var c1 = await testEnvironment.ConnectClientAsync(new MqttClientOptionsBuilder().WithClientId("client1"));
await Task.Delay(500);
await Task.Delay(1000);


var clientStatus = await server.GetClientStatusAsync(); var clientStatus = await server.GetClientStatusAsync();
Assert.AreEqual(1, clientStatus.Count); Assert.AreEqual(1, clientStatus.Count);
Assert.IsTrue(clientStatus.Any(s => s.ClientId == "client1")); Assert.IsTrue(clientStatus.Any(s => s.ClientId == "client1"));


+ 127
- 1
Tests/MQTTnet.Core.Tests/Server_Tests.cs View File

@@ -917,6 +917,109 @@ namespace MQTTnet.Tests
} }
} }



private Dictionary<string, bool> _connected;
private void ConnectionValidationHandler(MqttConnectionValidatorContext eventArgs)
{
if (_connected.ContainsKey(eventArgs.ClientId))
{
eventArgs.ReasonCode = MqttConnectReasonCode.BadUserNameOrPassword;
return;
}
_connected[eventArgs.ClientId] = true;
eventArgs.ReasonCode = MqttConnectReasonCode.Success;
return;
}

[TestMethod]
public async Task Same_Client_Id_Refuse_Connection()
{
using (var testEnvironment = new TestEnvironment())
{
testEnvironment.IgnoreClientLogErrors = true;

_connected = new Dictionary<string, bool>();
var options = new MqttServerOptionsBuilder();
options.WithConnectionValidator(e => ConnectionValidationHandler(e));
var server = await testEnvironment.StartServerAsync(options);

var events = new List<string>();

server.ClientConnectedHandler = new MqttServerClientConnectedHandlerDelegate(_ =>
{
lock (events)
{
events.Add("c");
}
});

server.ClientDisconnectedHandler = new MqttServerClientDisconnectedHandlerDelegate(_ =>
{
lock (events)
{
events.Add("d");
}
});

var clientOptions = new MqttClientOptionsBuilder()
.WithClientId("same_id");

// c
var c1 = await testEnvironment.ConnectClientAsync(clientOptions);

c1.UseDisconnectedHandler(_ =>
{
lock (events)
{
events.Add("x");
}
});


c1.UseApplicationMessageReceivedHandler(_ =>
{
lock (events)
{
events.Add("r");
}

});

c1.SubscribeAsync("topic").Wait();

await Task.Delay(500);

c1.PublishAsync("topic").Wait();

await Task.Delay(500);

var flow = string.Join(string.Empty, events);
Assert.AreEqual("cr", flow);

try
{
await testEnvironment.ConnectClientAsync(clientOptions);
Assert.Fail("same id connection is expected to fail");
}
catch
{
//same id connection is expected to fail
}

await Task.Delay(500);

flow = string.Join(string.Empty, events);
Assert.AreEqual("cr", flow);

c1.PublishAsync("topic").Wait();

await Task.Delay(500);

flow = string.Join(string.Empty, events);
Assert.AreEqual("crr", flow);
}
}

[TestMethod] [TestMethod]
public async Task Same_Client_Id_Connect_Disconnect_Event_Order() public async Task Same_Client_Id_Connect_Disconnect_Event_Order()
{ {
@@ -956,17 +1059,40 @@ namespace MQTTnet.Tests
// dc // dc
var c2 = await testEnvironment.ConnectClientAsync(clientOptions); var c2 = await testEnvironment.ConnectClientAsync(clientOptions);


c2.UseApplicationMessageReceivedHandler(_ =>
{
lock (events)
{
events.Add("r");
}

});
c2.SubscribeAsync("topic").Wait();

await Task.Delay(500); await Task.Delay(500);


flow = string.Join(string.Empty, events); flow = string.Join(string.Empty, events);
Assert.AreEqual("cdc", flow); Assert.AreEqual("cdc", flow);


// r
c2.PublishAsync("topic").Wait();

await Task.Delay(500);

flow = string.Join(string.Empty, events);
Assert.AreEqual("cdcr", flow);


// nothing // nothing

Assert.AreEqual(false, c1.IsConnected);
await c1.DisconnectAsync(); await c1.DisconnectAsync();
Assert.AreEqual (false, c1.IsConnected);


await Task.Delay(500); await Task.Delay(500);


// d // d
Assert.AreEqual(true, c2.IsConnected);
await c2.DisconnectAsync(); await c2.DisconnectAsync();


await Task.Delay(500); await Task.Delay(500);
@@ -974,7 +1100,7 @@ namespace MQTTnet.Tests
await server.StopAsync(); await server.StopAsync();


flow = string.Join(string.Empty, events); flow = string.Join(string.Empty, events);
Assert.AreEqual("cdcd", flow);
Assert.AreEqual("cdcrd", flow);
} }
} }




Loading…
Cancel
Save