浏览代码

Refactor subscriptions manager

release/3.x.x
Christian 6 年前
父节点
当前提交
b22b02a0b6
共有 12 个文件被更改,包括 267 次插入105 次删除
  1. +12
    -3
      Frameworks/MQTTnet.NetStandard/Client/MqttClient.cs
  2. +26
    -0
      Frameworks/MQTTnet.NetStandard/Internal/AsyncAutoResetEvent.cs
  3. +26
    -0
      Frameworks/MQTTnet.NetStandard/Internal/AsyncLock.cs
  4. +7
    -6
      Frameworks/MQTTnet.NetStandard/Server/MqttClientPendingMessagesQueue.cs
  5. +8
    -6
      Frameworks/MQTTnet.NetStandard/Server/MqttClientSession.cs
  6. +24
    -39
      Frameworks/MQTTnet.NetStandard/Server/MqttClientSubscriptionsManager.cs
  7. +34
    -0
      Tests/MQTTnet.Core.Tests/AsyncAutoResentEventTests.cs
  8. +44
    -0
      Tests/MQTTnet.Core.Tests/AsyncLockTests.cs
  9. +56
    -21
      Tests/MQTTnet.Core.Tests/MqttServerTests.cs
  10. +6
    -6
      Tests/MQTTnet.Core.Tests/MqttSubscriptionsManagerTests.cs
  11. +21
    -2
      Tests/MQTTnet.Core.Tests/TestMqttCommunicationAdapter.cs
  12. +3
    -22
      Tests/MQTTnet.Core.Tests/TestMqttServerAdapter.cs

+ 12
- 3
Frameworks/MQTTnet.NetStandard/Client/MqttClient.cs 查看文件

@@ -290,18 +290,27 @@ namespace MQTTnet.Client

private Task SendAsync(MqttBasePacket packet, CancellationToken cancellationToken)
{
_sendTracker.Restart();
return _adapter.SendPacketsAsync(_options.CommunicationTimeout, new[] { packet }, cancellationToken);
return SendAsync(new[] { packet }, cancellationToken);
}

private Task SendAsync(IEnumerable<MqttBasePacket> packets, CancellationToken cancellationToken)
{
if (cancellationToken.IsCancellationRequested)
{
throw new TaskCanceledException();
}

_sendTracker.Restart();
return _adapter.SendPacketsAsync(_options.CommunicationTimeout, packets, cancellationToken);
}

private async Task<TResponsePacket> SendAndReceiveAsync<TResponsePacket>(MqttBasePacket requestPacket, CancellationToken cancellationToken) where TResponsePacket : MqttBasePacket
{
if (cancellationToken.IsCancellationRequested)
{
throw new TaskCanceledException();
}

_sendTracker.Restart();

ushort identifier = 0;
@@ -528,7 +537,7 @@ namespace MQTTnet.Client

private static async Task WaitForTaskAsync(Task task, Task sender)
{
if (task == sender)
if (task == sender || task == null)
{
return;
}


+ 26
- 0
Frameworks/MQTTnet.NetStandard/Internal/AsyncAutoResetEvent.cs 查看文件

@@ -0,0 +1,26 @@
using System;
using System.Threading;
using System.Threading.Tasks;

namespace MQTTnet.Internal
{
public sealed class AsyncAutoResetEvent : IDisposable
{
private readonly SemaphoreSlim _semaphore = new SemaphoreSlim(0, 1);

public Task WaitOneAsync(CancellationToken cancellationToken)
{
return _semaphore.WaitAsync(cancellationToken);
}

public void Set()
{
_semaphore.Release();
}

public void Dispose()
{
_semaphore?.Dispose();
}
}
}

+ 26
- 0
Frameworks/MQTTnet.NetStandard/Internal/AsyncLock.cs 查看文件

@@ -0,0 +1,26 @@
using System;
using System.Threading;
using System.Threading.Tasks;

namespace MQTTnet.Internal
{
public sealed class AsyncLock : IDisposable
{
private readonly SemaphoreSlim _semaphore = new SemaphoreSlim(1, 1);

public Task EnterAsync(CancellationToken cancellationToken)
{
return _semaphore.WaitAsync(cancellationToken);
}

public void Exit()
{
_semaphore.Release();
}

public void Dispose()
{
_semaphore?.Dispose();
}
}
}

+ 7
- 6
Frameworks/MQTTnet.NetStandard/Server/MqttClientPendingMessagesQueue.cs 查看文件

@@ -5,6 +5,7 @@ using System.Threading.Tasks;
using MQTTnet.Adapter;
using MQTTnet.Diagnostics;
using MQTTnet.Exceptions;
using MQTTnet.Internal;
using MQTTnet.Packets;
using MQTTnet.Protocol;

@@ -13,7 +14,7 @@ namespace MQTTnet.Server
public sealed class MqttClientPendingMessagesQueue : IDisposable
{
private readonly ConcurrentQueue<MqttBasePacket> _queue = new ConcurrentQueue<MqttBasePacket>();
private readonly SemaphoreSlim _queueWaitSemaphore = new SemaphoreSlim(0);
private readonly AsyncAutoResetEvent _queueAutoResetEvent = new AsyncAutoResetEvent();
private readonly IMqttServerOptions _options;
private readonly MqttClientSession _clientSession;
private readonly IMqttNetLogger _logger;
@@ -54,7 +55,7 @@ namespace MQTTnet.Server
if (packet == null) throw new ArgumentNullException(nameof(packet));

_queue.Enqueue(packet);
_queueWaitSemaphore.Release();
_queueAutoResetEvent.Set();

_logger.Verbose<MqttClientPendingMessagesQueue>("Enqueued packet (ClientId: {0}).", _clientSession.ClientId);
}
@@ -82,7 +83,7 @@ namespace MQTTnet.Server
MqttBasePacket packet = null;
try
{
await _queueWaitSemaphore.WaitAsync(cancellationToken).ConfigureAwait(false);
await _queueAutoResetEvent.WaitOneAsync(cancellationToken).ConfigureAwait(false);
if (!_queue.TryDequeue(out packet))
{
throw new InvalidOperationException(); // should not happen
@@ -120,8 +121,8 @@ namespace MQTTnet.Server
if (publishPacket.QualityOfServiceLevel > MqttQualityOfServiceLevel.AtMostOnce)
{
publishPacket.Dup = true;
_queue.Enqueue(packet);
_queueWaitSemaphore.Release();
Enqueue(publishPacket);
}
}

@@ -134,7 +135,7 @@ namespace MQTTnet.Server

public void Dispose()
{
_queueWaitSemaphore?.Dispose();
_queueAutoResetEvent?.Dispose();
}
}
}

+ 8
- 6
Frameworks/MQTTnet.NetStandard/Server/MqttClientSession.cs 查看文件

@@ -134,7 +134,7 @@ namespace MQTTnet.Server
{
if (applicationMessage == null) throw new ArgumentNullException(nameof(applicationMessage));

var result = await SubscriptionsManager.CheckSubscriptionsAsync(applicationMessage);
var result = await SubscriptionsManager.CheckSubscriptionsAsync(applicationMessage).ConfigureAwait(false);
if (!result.IsSubscribed)
{
return;
@@ -155,10 +155,10 @@ namespace MQTTnet.Server
{
if (topicFilters == null) throw new ArgumentNullException(nameof(topicFilters));

await SubscriptionsManager.SubscribeAsync(new MqttSubscribePacket
SubscriptionsManager.Subscribe(new MqttSubscribePacket
{
TopicFilters = topicFilters
}).ConfigureAwait(false);
});

await EnqueueSubscribedRetainedMessagesAsync(topicFilters).ConfigureAwait(false);
}
@@ -167,10 +167,12 @@ namespace MQTTnet.Server
{
if (topicFilters == null) throw new ArgumentNullException(nameof(topicFilters));

return SubscriptionsManager.UnsubscribeAsync(new MqttUnsubscribePacket
SubscriptionsManager.Unsubscribe(new MqttUnsubscribePacket
{
TopicFilters = topicFilters
});

return Task.FromResult(0);
}

public void Dispose()
@@ -288,7 +290,7 @@ namespace MQTTnet.Server

private async Task HandleIncomingSubscribePacketAsync(IMqttChannelAdapter adapter, MqttSubscribePacket subscribePacket, CancellationToken cancellationToken)
{
var subscribeResult = await SubscriptionsManager.SubscribeAsync(subscribePacket).ConfigureAwait(false);
var subscribeResult = SubscriptionsManager.Subscribe(subscribePacket);
await adapter.SendPacketsAsync(_options.DefaultCommunicationTimeout, new[] { subscribeResult.ResponsePacket }, cancellationToken).ConfigureAwait(false);

if (subscribeResult.CloseConnection)
@@ -302,7 +304,7 @@ namespace MQTTnet.Server

private async Task HandleIncomingUnsubscribePacketAsync(IMqttChannelAdapter adapter, MqttUnsubscribePacket unsubscribePacket, CancellationToken cancellationToken)
{
var unsubscribeResult = await SubscriptionsManager.UnsubscribeAsync(unsubscribePacket).ConfigureAwait(false);
var unsubscribeResult = SubscriptionsManager.Unsubscribe(unsubscribePacket);
await adapter.SendPacketsAsync(_options.DefaultCommunicationTimeout, new[] { unsubscribeResult }, cancellationToken);
}



+ 24
- 39
Frameworks/MQTTnet.NetStandard/Server/MqttClientSubscriptionsManager.cs 查看文件

@@ -1,4 +1,5 @@
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
@@ -10,7 +11,7 @@ namespace MQTTnet.Server
{
public sealed class MqttClientSubscriptionsManager : IDisposable
{
private readonly Dictionary<string, MqttQualityOfServiceLevel> _subscriptions = new Dictionary<string, MqttQualityOfServiceLevel>();
private readonly ConcurrentDictionary<string, MqttQualityOfServiceLevel> _subscriptions = new ConcurrentDictionary<string, MqttQualityOfServiceLevel>();
private readonly SemaphoreSlim _semaphore = new SemaphoreSlim(1, 1);
private readonly IMqttServerOptions _options;
private readonly MqttServer _server;
@@ -23,7 +24,7 @@ namespace MQTTnet.Server
_server = server;
}

public async Task<MqttClientSubscribeResult> SubscribeAsync(MqttSubscribePacket subscribePacket)
public MqttClientSubscribeResult Subscribe(MqttSubscribePacket subscribePacket)
{
if (subscribePacket == null) throw new ArgumentNullException(nameof(subscribePacket));

@@ -37,57 +38,41 @@ namespace MQTTnet.Server
CloseConnection = false
};

await _semaphore.WaitAsync().ConfigureAwait(false);
try
foreach (var topicFilter in subscribePacket.TopicFilters)
{
foreach (var topicFilter in subscribePacket.TopicFilters)
var interceptorContext = InterceptSubscribe(topicFilter);
if (!interceptorContext.AcceptSubscription)
{
var interceptorContext = InterceptSubscribe(topicFilter);
if (!interceptorContext.AcceptSubscription)
{
result.ResponsePacket.SubscribeReturnCodes.Add(MqttSubscribeReturnCode.Failure);
}
else
{
result.ResponsePacket.SubscribeReturnCodes.Add(ConvertToMaximumQoS(topicFilter.QualityOfServiceLevel));
}
result.ResponsePacket.SubscribeReturnCodes.Add(MqttSubscribeReturnCode.Failure);
}
else
{
result.ResponsePacket.SubscribeReturnCodes.Add(ConvertToMaximumQoS(topicFilter.QualityOfServiceLevel));
}

if (interceptorContext.CloseConnection)
{
result.CloseConnection = true;
}
if (interceptorContext.CloseConnection)
{
result.CloseConnection = true;
}

if (interceptorContext.AcceptSubscription)
{
_subscriptions[topicFilter.Topic] = topicFilter.QualityOfServiceLevel;
_server.OnClientSubscribedTopic(_clientId, topicFilter);
}
if (interceptorContext.AcceptSubscription)
{
_subscriptions[topicFilter.Topic] = topicFilter.QualityOfServiceLevel;
_server.OnClientSubscribedTopic(_clientId, topicFilter);
}
}
finally
{
_semaphore.Release();
}

return result;
}

public async Task<MqttUnsubAckPacket> UnsubscribeAsync(MqttUnsubscribePacket unsubscribePacket)
public MqttUnsubAckPacket Unsubscribe(MqttUnsubscribePacket unsubscribePacket)
{
if (unsubscribePacket == null) throw new ArgumentNullException(nameof(unsubscribePacket));

await _semaphore.WaitAsync().ConfigureAwait(false);
try
foreach (var topicFilter in unsubscribePacket.TopicFilters)
{
foreach (var topicFilter in unsubscribePacket.TopicFilters)
{
_subscriptions.Remove(topicFilter);
_server.OnClientUnsubscribedTopic(_clientId, topicFilter);
}
}
finally
{
_semaphore.Release();
_subscriptions.TryRemove(topicFilter, out _);
_server.OnClientUnsubscribedTopic(_clientId, topicFilter);
}

return new MqttUnsubAckPacket


+ 34
- 0
Tests/MQTTnet.Core.Tests/AsyncAutoResentEventTests.cs 查看文件

@@ -0,0 +1,34 @@
using System.Threading;
using System.Threading.Tasks;
using Microsoft.VisualStudio.TestTools.UnitTesting;
using MQTTnet.Internal;

namespace MQTTnet.Core.Tests
{
[TestClass]
public class AsyncAutoResetEventTests
{
[TestMethod]
public async Task AsyncAutoResetEvent()
{
var aare = new AsyncAutoResetEvent();

var increment = 0;
var globalI = 0;
#pragma warning disable 4014
Task.Run(async () =>
#pragma warning restore 4014
{
await aare.WaitOneAsync(CancellationToken.None);
globalI += increment;
});

await Task.Delay(500);
increment = 1;
aare.Set();
await Task.Delay(100);

Assert.AreEqual(1, globalI);
}
}
}

+ 44
- 0
Tests/MQTTnet.Core.Tests/AsyncLockTests.cs 查看文件

@@ -0,0 +1,44 @@
using System.Threading;
using System.Threading.Tasks;
using Microsoft.VisualStudio.TestTools.UnitTesting;
using MQTTnet.Internal;

namespace MQTTnet.Core.Tests
{
[TestClass]
public class AsyncLockTests
{
[TestMethod]
public void AsyncLock()
{
const int ThreadsCount = 10;

var threads = new Task[ThreadsCount];
var @lock = new AsyncLock();
var globalI = 0;
for (var i = 0; i < ThreadsCount; i++)
{
#pragma warning disable CS4014 // Because this call is not awaited, execution of the current method continues before the call is completed
threads[i] = Task.Run(async () =>
{
await @lock.EnterAsync(CancellationToken.None);
try
{
var localI = globalI;
await Task.Delay(10); // Increase the chance for wrong data.
localI++;
globalI = localI;
}
finally
{
@lock.Exit();
}
});
#pragma warning restore CS4014 // Because this call is not awaited, execution of the current method continues before the call is completed
}
Task.WaitAll(threads);
Assert.AreEqual(ThreadsCount, globalI);
}
}
}

+ 56
- 21
Tests/MQTTnet.Core.Tests/MqttServerTests.cs 查看文件

@@ -1,6 +1,7 @@
using System;
using System.Collections.Generic;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.VisualStudio.TestTools.UnitTesting;
using MQTTnet.Diagnostics;
@@ -58,8 +59,8 @@ namespace MQTTnet.Core.Tests
await s.StartAsync(new MqttServerOptions());

var willMessage = new MqttApplicationMessageBuilder().WithTopic("My/last/will").WithAtMostOnceQoS().Build();
var c1 = await serverAdapter.ConnectTestClient(s, "c1");
var c2 = await serverAdapter.ConnectTestClient(s, "c2", willMessage);
var c1 = await serverAdapter.ConnectTestClient("c1");
var c2 = await serverAdapter.ConnectTestClient("c2", willMessage);

c1.ApplicationMessageReceived += (_, __) => receivedMessagesCount++;
await c1.SubscribeAsync(new TopicFilterBuilder().WithTopic("#").Build());
@@ -90,8 +91,8 @@ namespace MQTTnet.Core.Tests
{
await s.StartAsync(new MqttServerOptions());

var c1 = await serverAdapter.ConnectTestClient(s, "c1");
var c2 = await serverAdapter.ConnectTestClient(s, "c2");
var c1 = await serverAdapter.ConnectTestClient("c1");
var c2 = await serverAdapter.ConnectTestClient("c2");
c1.ApplicationMessageReceived += (_, __) => receivedMessagesCount++;

var message = new MqttApplicationMessageBuilder().WithTopic("a").WithAtLeastOnceQoS().Build();
@@ -149,7 +150,7 @@ namespace MQTTnet.Core.Tests
{
await s.StartAsync(new MqttServerOptions());

var c1 = await serverAdapter.ConnectTestClient(s, "c1");
var c1 = await serverAdapter.ConnectTestClient("c1");

c1.ApplicationMessageReceived += (_, __) => receivedMessagesCount++;

@@ -167,6 +168,40 @@ namespace MQTTnet.Core.Tests
Assert.AreEqual(1, receivedMessagesCount);
}

[TestMethod]
public async Task MqttServer_RetainedMessagesFlow()
{
var retainedMessage = new MqttApplicationMessageBuilder().WithTopic("r").WithPayload("r").WithRetainFlag().Build();
var serverAdapter = new TestMqttServerAdapter();
var s = new MqttFactory().CreateMqttServer(new[] { serverAdapter }, new MqttNetLogger());
await s.StartAsync(new MqttServerOptions());
var c1 = await serverAdapter.ConnectTestClient("c1");
await c1.PublishAsync(retainedMessage);
Thread.Sleep(500);
await c1.DisconnectAsync();
Thread.Sleep(500);

var receivedMessages = 0;
var c2 = await serverAdapter.ConnectTestClient("c2");
c2.ApplicationMessageReceived += (_, e) =>
{
receivedMessages++;
};

for (var i = 0; i < 5; i++)
{
await c2.UnsubscribeAsync("r");
await Task.Delay(500);
Assert.AreEqual(i, receivedMessages);

await c2.SubscribeAsync("r");
await Task.Delay(500);
Assert.AreEqual(i + 1, receivedMessages);
}

await c2.DisconnectAsync();
}

[TestMethod]
public async Task MqttServer_NoRetainedMessage()
{
@@ -179,11 +214,11 @@ namespace MQTTnet.Core.Tests
{
await s.StartAsync(new MqttServerOptions());

var c1 = await serverAdapter.ConnectTestClient(s, "c1");
var c1 = await serverAdapter.ConnectTestClient("c1");
await c1.PublishAsync(new MqttApplicationMessageBuilder().WithTopic("retained").WithPayload(new byte[3]).Build());
await c1.DisconnectAsync();

var c2 = await serverAdapter.ConnectTestClient(s, "c2");
var c2 = await serverAdapter.ConnectTestClient("c2");
c2.ApplicationMessageReceived += (_, __) => receivedMessagesCount++;
await c2.SubscribeAsync(new TopicFilterBuilder().WithTopic("retained").Build());

@@ -208,11 +243,11 @@ namespace MQTTnet.Core.Tests
{
await s.StartAsync(new MqttServerOptions());

var c1 = await serverAdapter.ConnectTestClient(s, "c1");
var c1 = await serverAdapter.ConnectTestClient("c1");
await c1.PublishAndWaitForAsync(s, new MqttApplicationMessageBuilder().WithTopic("retained").WithPayload(new byte[3]).WithRetainFlag().Build());
await c1.DisconnectAsync();

var c2 = await serverAdapter.ConnectTestClient(s, "c2");
var c2 = await serverAdapter.ConnectTestClient("c2");
c2.ApplicationMessageReceived += (_, __) => receivedMessagesCount++;
await c2.SubscribeAsync(new TopicFilterBuilder().WithTopic("retained").Build());

@@ -237,14 +272,14 @@ namespace MQTTnet.Core.Tests
{
await s.StartAsync(new MqttServerOptions());

var c1 = await serverAdapter.ConnectTestClient(s, "c1");
var c1 = await serverAdapter.ConnectTestClient("c1");
await c1.PublishAsync(new MqttApplicationMessageBuilder().WithTopic("retained").WithPayload(new byte[3]).WithRetainFlag().Build());
await c1.PublishAsync(new MqttApplicationMessageBuilder().WithTopic("retained").WithPayload(new byte[0]).WithRetainFlag().Build());
await c1.DisconnectAsync();
var c2 = await serverAdapter.ConnectTestClient(s, "c2");
var c2 = await serverAdapter.ConnectTestClient("c2");
c2.ApplicationMessageReceived += (_, __) => receivedMessagesCount++;
await Task.Delay(200);
await c2.SubscribeAsync(new TopicFilter("retained", MqttQualityOfServiceLevel.AtMostOnce));
await Task.Delay(500);
@@ -271,7 +306,7 @@ namespace MQTTnet.Core.Tests

await s.StartAsync(options);

var c1 = await serverAdapter.ConnectTestClient(s, "c1");
var c1 = await serverAdapter.ConnectTestClient("c1");

await c1.PublishAndWaitForAsync(s, new MqttApplicationMessageBuilder().WithTopic("retained").WithPayload(new byte[3]).WithRetainFlag().Build());
await c1.DisconnectAsync();
@@ -291,7 +326,7 @@ namespace MQTTnet.Core.Tests
var options = new MqttServerOptions { Storage = storage };
await s.StartAsync(options);

var c2 = await serverAdapter.ConnectTestClient(s, "c2");
var c2 = await serverAdapter.ConnectTestClient("c2");
c2.ApplicationMessageReceived += (_, __) => receivedMessagesCount++;
await c2.SubscribeAsync(new TopicFilterBuilder().WithTopic("retained").Build());

@@ -322,8 +357,8 @@ namespace MQTTnet.Core.Tests

await s.StartAsync(options);

var c1 = await serverAdapter.ConnectTestClient(s, "c1");
var c2 = await serverAdapter.ConnectTestClient(s, "c2");
var c1 = await serverAdapter.ConnectTestClient("c1");
var c2 = await serverAdapter.ConnectTestClient("c2");
await c2.SubscribeAsync(new TopicFilterBuilder().WithTopic("test").Build());

var isIntercepted = false;
@@ -357,8 +392,8 @@ namespace MQTTnet.Core.Tests
{
await s.StartAsync(new MqttServerOptions());

var c1 = await serverAdapter.ConnectTestClient(s, "c1");
var c2 = await serverAdapter.ConnectTestClient(s, "c2");
var c1 = await serverAdapter.ConnectTestClient("c1");
var c2 = await serverAdapter.ConnectTestClient("c2");

c1.ApplicationMessageReceived += (_, e) =>
{
@@ -412,8 +447,8 @@ namespace MQTTnet.Core.Tests
{
await s.StartAsync(new MqttServerOptions());

var c1 = await serverAdapter.ConnectTestClient(s, "c1");
var c2 = await serverAdapter.ConnectTestClient(s, "c2");
var c1 = await serverAdapter.ConnectTestClient("c1");
var c2 = await serverAdapter.ConnectTestClient("c2");

c1.ApplicationMessageReceived += (_, __) => receivedMessagesCount++;



+ 6
- 6
Tests/MQTTnet.Core.Tests/MqttSubscriptionsManagerTests.cs 查看文件

@@ -18,7 +18,7 @@ namespace MQTTnet.Core.Tests
var sp = new MqttSubscribePacket();
sp.TopicFilters.Add(new TopicFilterBuilder().WithTopic("A/B/C").Build());

sm.SubscribeAsync(sp).Wait();
sm.Subscribe(sp);

var pp = new MqttApplicationMessage
{
@@ -39,7 +39,7 @@ namespace MQTTnet.Core.Tests
var sp = new MqttSubscribePacket();
sp.TopicFilters.Add(new TopicFilter("A/B/C", MqttQualityOfServiceLevel.AtMostOnce));

sm.SubscribeAsync(sp).Wait();
sm.Subscribe(sp);

var pp = new MqttApplicationMessage
{
@@ -61,7 +61,7 @@ namespace MQTTnet.Core.Tests
sp.TopicFilters.Add(new TopicFilter("#", MqttQualityOfServiceLevel.AtMostOnce));
sp.TopicFilters.Add(new TopicFilter("A/B/C", MqttQualityOfServiceLevel.AtLeastOnce));

sm.SubscribeAsync(sp).Wait();
sm.Subscribe(sp);

var pp = new MqttApplicationMessage
{
@@ -82,7 +82,7 @@ namespace MQTTnet.Core.Tests
var sp = new MqttSubscribePacket();
sp.TopicFilters.Add(new TopicFilterBuilder().WithTopic("A/B/C").Build());

sm.SubscribeAsync(sp).Wait();
sm.Subscribe(sp);

var pp = new MqttApplicationMessage
{
@@ -101,7 +101,7 @@ namespace MQTTnet.Core.Tests
var sp = new MqttSubscribePacket();
sp.TopicFilters.Add(new TopicFilterBuilder().WithTopic("A/B/C").Build());

sm.SubscribeAsync(sp).Wait();
sm.Subscribe(sp);

var pp = new MqttApplicationMessage
{
@@ -113,7 +113,7 @@ namespace MQTTnet.Core.Tests

var up = new MqttUnsubscribePacket();
up.TopicFilters.Add("A/B/C");
sm.UnsubscribeAsync(up).Wait();
sm.Unsubscribe(up);

Assert.IsFalse(sm.CheckSubscriptionsAsync(pp).Result.IsSubscribed);
}


+ 21
- 2
Tests/MQTTnet.Core.Tests/TestMqttCommunicationAdapter.cs 查看文件

@@ -43,11 +43,30 @@ namespace MQTTnet.Core.Tests
return Task.FromResult(0);
}

public Task<MqttBasePacket> ReceivePacketAsync(TimeSpan timeout, CancellationToken cancellationToken)
public async Task<MqttBasePacket> ReceivePacketAsync(TimeSpan timeout, CancellationToken cancellationToken)
{
ThrowIfPartnerIsNull();

return Task.Run(() =>
if (timeout > TimeSpan.Zero)
{
using (var timeoutCts = new CancellationTokenSource(timeout))
using (var cts = CancellationTokenSource.CreateLinkedTokenSource(timeoutCts.Token, cancellationToken))
{
return await Task.Run(() =>
{
try
{
return _incomingPackets.Take(cts.Token);
}
catch
{
return null;
}
}, cts.Token);
}
}

return await Task.Run(() =>
{
try
{


+ 3
- 22
Tests/MQTTnet.Core.Tests/TestMqttServerAdapter.cs 查看文件

@@ -1,4 +1,5 @@
using System;
using System.Threading;
using System.Threading.Tasks;
using MQTTnet.Adapter;
using MQTTnet.Client;
@@ -11,7 +12,7 @@ namespace MQTTnet.Core.Tests
{
public event EventHandler<MqttServerAdapterClientAcceptedEventArgs> ClientAccepted;
public async Task<IMqttClient> ConnectTestClient(IMqttServer server, string clientId, MqttApplicationMessage willMessage = null)
public async Task<IMqttClient> ConnectTestClient(string clientId, MqttApplicationMessage willMessage = null)
{
var adapterA = new TestMqttCommunicationAdapter();
var adapterB = new TestMqttCommunicationAdapter();
@@ -22,8 +23,6 @@ namespace MQTTnet.Core.Tests
new TestMqttCommunicationAdapterFactory(adapterA),
new MqttNetLogger());

var connected = WaitForClientToConnect(server, clientId);

FireClientAcceptedEvent(adapterB);

var options = new MqttClientOptions
@@ -34,29 +33,11 @@ namespace MQTTnet.Core.Tests
};

await client.ConnectAsync(options);
await connected;
SpinWait.SpinUntil(() => client.IsConnected);

return client;
}
private static Task WaitForClientToConnect(IMqttServer s, string clientId)
{
var tcs = new TaskCompletionSource<object>();

void Handler(object sender, Server.MqttClientConnectedEventArgs args)
{
if (args.Client.ClientId == clientId)
{
s.ClientConnected -= Handler;
tcs.SetResult(null);
}
}

s.ClientConnected += Handler;

return tcs.Task;
}

private void FireClientAcceptedEvent(IMqttChannelAdapter adapter)
{
ClientAccepted?.Invoke(this, new MqttServerAdapterClientAcceptedEventArgs(adapter));


正在加载...
取消
保存