Browse Source

Add cancellation token to adapter

release/3.x.x
Christian Kratky 7 years ago
parent
commit
e7c8d1c1c1
13 changed files with 215 additions and 145 deletions
  1. +3
    -2
      MQTTnet.Core/Adapter/IMqttCommunicationAdapter.cs
  2. +35
    -18
      MQTTnet.Core/Adapter/MqttChannelCommunicationAdapter.cs
  3. +3
    -2
      MQTTnet.Core/Adapter/MqttCommunicationAdapterExtensions.cs
  4. +4
    -4
      MQTTnet.Core/Client/IMqttClient.cs
  5. +85
    -91
      MQTTnet.Core/Client/MqttClient.cs
  6. +24
    -1
      MQTTnet.Core/Client/MqttClientExtensions.cs
  7. +1
    -1
      MQTTnet.Core/Server/MqttClientMessageQueue.cs
  8. +8
    -8
      MQTTnet.Core/Server/MqttClientSession.cs
  9. +4
    -4
      MQTTnet.Core/Server/MqttClientSessionsManager.cs
  10. +41
    -3
      Tests/MQTTnet.Core.Tests/ExtensionTests.cs
  11. +2
    -1
      Tests/MQTTnet.Core.Tests/MqttPacketSerializerTests.cs
  12. +2
    -2
      Tests/MQTTnet.Core.Tests/MqttServerTests.cs
  13. +3
    -8
      Tests/MQTTnet.Core.Tests/TestMqttCommunicationAdapter.cs

+ 3
- 2
MQTTnet.Core/Adapter/IMqttCommunicationAdapter.cs View File

@@ -1,5 +1,6 @@
using System;
using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;
using MQTTnet.Core.Client;
using MQTTnet.Core.Packets;
@@ -15,8 +16,8 @@ namespace MQTTnet.Core.Adapter

Task DisconnectAsync(TimeSpan timeout);

Task SendPacketsAsync(TimeSpan timeout, IEnumerable<MqttBasePacket> packets);
Task SendPacketsAsync(TimeSpan timeout, CancellationToken cancellationToken, IEnumerable<MqttBasePacket> packets);

Task<MqttBasePacket> ReceivePacketAsync(TimeSpan timeout);
Task<MqttBasePacket> ReceivePacketAsync(TimeSpan timeout, CancellationToken cancellationToken);
}
}

+ 35
- 18
MQTTnet.Core/Adapter/MqttChannelCommunicationAdapter.cs View File

@@ -1,6 +1,7 @@
using System;
using System.Collections.Generic;
using System.IO;
using System.Threading;
using System.Threading.Tasks;
using MQTTnet.Core.Channel;
using MQTTnet.Core.Client;
@@ -30,7 +31,11 @@ namespace MQTTnet.Core.Adapter
{
try
{
await _channel.ConnectAsync(options).TimeoutAfter(timeout);
await _channel.ConnectAsync(options).TimeoutAfter(timeout).ConfigureAwait(false);
}
catch (TaskCanceledException)
{
throw;
}
catch (MqttCommunicationTimedOutException)
{
@@ -52,6 +57,10 @@ namespace MQTTnet.Core.Adapter
{
await _channel.DisconnectAsync().TimeoutAfter(timeout).ConfigureAwait(false);
}
catch (TaskCanceledException)
{
throw;
}
catch (MqttCommunicationTimedOutException)
{
throw;
@@ -66,7 +75,7 @@ namespace MQTTnet.Core.Adapter
}
}

public async Task SendPacketsAsync(TimeSpan timeout, IEnumerable<MqttBasePacket> packets)
public async Task SendPacketsAsync(TimeSpan timeout, CancellationToken cancellationToken, IEnumerable<MqttBasePacket> packets)
{
try
{
@@ -77,20 +86,24 @@ namespace MQTTnet.Core.Adapter
MqttTrace.Information(nameof(MqttChannelCommunicationAdapter), "TX >>> {0} [Timeout={1}]", packet, timeout);

var writeBuffer = PacketSerializer.Serialize(packet);
_sendTask = _sendTask.ContinueWith(p => _channel.SendStream.WriteAsync(writeBuffer, 0, writeBuffer.Length));
_sendTask = _sendTask.ContinueWith(p => _channel.SendStream.WriteAsync(writeBuffer, 0, writeBuffer.Length, cancellationToken).ConfigureAwait(false), cancellationToken);
}
}

await _sendTask; // configure await false geneates stackoverflow
await _sendTask; // configure await false generates stackoverflow

if (timeout > TimeSpan.Zero)
{
await _channel.SendStream.FlushAsync().TimeoutAfter(timeout).ConfigureAwait(false);
await _channel.SendStream.FlushAsync(cancellationToken).TimeoutAfter(timeout).ConfigureAwait(false);
}
else
{
await _channel.SendStream.FlushAsync().ConfigureAwait(false);
}
await _channel.SendStream.FlushAsync(cancellationToken).ConfigureAwait(false);
}
}
catch (TaskCanceledException)
{
throw;
}
catch (MqttCommunicationTimedOutException)
{
@@ -106,18 +119,23 @@ namespace MQTTnet.Core.Adapter
}
}

public async Task<MqttBasePacket> ReceivePacketAsync(TimeSpan timeout)
public async Task<MqttBasePacket> ReceivePacketAsync(TimeSpan timeout, CancellationToken cancellationToken)
{
try
{
ReceivedMqttPacket receivedMqttPacket;
if (timeout > TimeSpan.Zero)
{
receivedMqttPacket = await ReceiveAsync(_channel.RawReceiveStream).TimeoutAfter(timeout).ConfigureAwait(false);
receivedMqttPacket = await ReceiveAsync(_channel.RawReceiveStream, cancellationToken).TimeoutAfter(timeout).ConfigureAwait(false);
}
else
{
receivedMqttPacket = await ReceiveAsync(_channel.ReceiveStream).ConfigureAwait(false);
receivedMqttPacket = await ReceiveAsync(_channel.ReceiveStream, cancellationToken).ConfigureAwait(false);
}

if (cancellationToken.IsCancellationRequested)
{
throw new TaskCanceledException();
}

var packet = PacketSerializer.Deserialize(receivedMqttPacket);
@@ -129,6 +147,10 @@ namespace MQTTnet.Core.Adapter
MqttTrace.Information(nameof(MqttChannelCommunicationAdapter), "RX <<< {0}", packet);
return packet;
}
catch (TaskCanceledException)
{
throw;
}
catch (MqttCommunicationTimedOutException)
{
throw;
@@ -143,9 +165,9 @@ namespace MQTTnet.Core.Adapter
}
}

private static async Task<ReceivedMqttPacket> ReceiveAsync(Stream stream)
private static async Task<ReceivedMqttPacket> ReceiveAsync(Stream stream, CancellationToken cancellationToken)
{
var header = MqttPacketReader.ReadHeaderFromSource(stream);
var header = MqttPacketReader.ReadHeaderFromSource(stream, cancellationToken);

if (header.BodyLength == 0)
{
@@ -157,15 +179,10 @@ namespace MQTTnet.Core.Adapter
var offset = 0;
do
{
var readBytesCount = await stream.ReadAsync(body, offset, body.Length - offset).ConfigureAwait(false);
var readBytesCount = await stream.ReadAsync(body, offset, body.Length - offset, cancellationToken).ConfigureAwait(false);
offset += readBytesCount;
} while (offset < header.BodyLength);

if (offset > header.BodyLength)
{
throw new MqttCommunicationException($"Read more body bytes than required ({offset}/{header.BodyLength}).");
}

return new ReceivedMqttPacket(header, new MemoryStream(body, 0, body.Length));
}
}

+ 3
- 2
MQTTnet.Core/Adapter/MqttCommunicationAdapterExtensions.cs View File

@@ -1,4 +1,5 @@
using System;
using System.Threading;
using System.Threading.Tasks;
using MQTTnet.Core.Packets;

@@ -6,9 +7,9 @@ namespace MQTTnet.Core.Adapter
{
public static class MqttCommunicationAdapterExtensions
{
public static Task SendPacketsAsync(this IMqttCommunicationAdapter adapter, TimeSpan timeout, params MqttBasePacket[] packets)
public static Task SendPacketsAsync(this IMqttCommunicationAdapter adapter, TimeSpan timeout, CancellationToken cancellationToken, params MqttBasePacket[] packets)
{
return adapter.SendPacketsAsync(timeout, packets);
return adapter.SendPacketsAsync(timeout, cancellationToken, packets);
}
}
}

+ 4
- 4
MQTTnet.Core/Client/IMqttClient.cs View File

@@ -15,10 +15,10 @@ namespace MQTTnet.Core.Client

Task ConnectAsync(MqttApplicationMessage willApplicationMessage = null);
Task DisconnectAsync();

Task<IList<MqttSubscribeResult>> SubscribeAsync(IEnumerable<TopicFilter> topicFilters);
Task UnsubscribeAsync(IEnumerable<string> topicFilters);

Task PublishAsync(IEnumerable<MqttApplicationMessage> applicationMessages);
Task<IList<MqttSubscribeResult>> SubscribeAsync(IList<TopicFilter> topicFilters);
Task<IList<MqttSubscribeResult>> SubscribeAsync(params TopicFilter[] topicFilters);
Task Unsubscribe(IList<string> topicFilters);
Task Unsubscribe(params string[] topicFilters);
}
}

+ 85
- 91
MQTTnet.Core/Client/MqttClient.cs View File

@@ -15,12 +15,10 @@ namespace MQTTnet.Core.Client
public class MqttClient : IMqttClient
{
private readonly HashSet<ushort> _unacknowledgedPublishPackets = new HashSet<ushort>();

private readonly MqttPacketDispatcher _packetDispatcher = new MqttPacketDispatcher();
private readonly MqttClientOptions _options;
private readonly IMqttCommunicationAdapter _adapter;

private bool _disconnectedEventSuspended;
private int _latestPacketIdentifier;
private CancellationTokenSource _cancellationTokenSource;

@@ -33,30 +31,27 @@ namespace MQTTnet.Core.Client
}

public event EventHandler Connected;

public event EventHandler Disconnected;

public event EventHandler<MqttApplicationMessageReceivedEventArgs> ApplicationMessageReceived;

public bool IsConnected { get; private set; }
public bool IsConnected => _cancellationTokenSource != null && !_cancellationTokenSource.IsCancellationRequested;

public async Task ConnectAsync(MqttApplicationMessage willApplicationMessage = null)
{
MqttTrace.Verbose(nameof(MqttClient), "Trying to connect.");

if (IsConnected)
{
throw new MqttProtocolViolationException("It is not allowed to connect with a server after the connection is established.");
}
ThrowIfConnected("It is not allowed to connect with a server after the connection is established.");

try
{
_disconnectedEventSuspended = false;

MqttTrace.Verbose(nameof(MqttClient), "Trying to connect with server.");
await _adapter.ConnectAsync(_options.DefaultCommunicationTimeout, _options).ConfigureAwait(false);

MqttTrace.Verbose(nameof(MqttClient), "Connection with server established.");

_cancellationTokenSource = new CancellationTokenSource();
_latestPacketIdentifier = 0;
_packetDispatcher.Reset();

StartReceivePackets(_cancellationTokenSource.Token);

var connectPacket = new MqttConnectPacket
{
ClientId = _options.ClientId,
@@ -67,28 +62,19 @@ namespace MQTTnet.Core.Client
WillMessage = willApplicationMessage
};

_cancellationTokenSource = new CancellationTokenSource();
_latestPacketIdentifier = 0;
_packetDispatcher.Reset();

StartReceivePackets();

var response = await SendAndReceiveAsync<MqttConnAckPacket>(connectPacket).ConfigureAwait(false);
if (response.ConnectReturnCode != MqttConnectReturnCode.ConnectionAccepted)
{
await DisconnectInternalAsync().ConfigureAwait(false);
throw new MqttConnectingFailedException(response.ConnectReturnCode);
}

MqttTrace.Verbose(nameof(MqttClient), "MQTT connection with server established.");
Connected?.Invoke(this, EventArgs.Empty);

if (_options.KeepAlivePeriod != TimeSpan.Zero)
{
StartSendKeepAliveMessages();
StartSendKeepAliveMessages(_cancellationTokenSource.Token);
}

MqttTrace.Verbose(nameof(MqttClient), "MQTT connection with server established.");

IsConnected = true;
Connected?.Invoke(this, EventArgs.Empty);
}
catch (Exception)
{
@@ -114,56 +100,41 @@ namespace MQTTnet.Core.Client
}
}

public Task<IList<MqttSubscribeResult>> SubscribeAsync(params TopicFilter[] topicFilters)
public async Task<IList<MqttSubscribeResult>> SubscribeAsync(IEnumerable<TopicFilter> topicFilters)
{
if (topicFilters == null) throw new ArgumentNullException(nameof(topicFilters));

return SubscribeAsync(topicFilters.ToList());
}

public async Task<IList<MqttSubscribeResult>> SubscribeAsync(IList<TopicFilter> topicFilters)
{
if (topicFilters == null) throw new ArgumentNullException(nameof(topicFilters));
if (!topicFilters.Any()) throw new MqttProtocolViolationException("At least one topic filter must be set [MQTT-3.8.3-3].");

ThrowIfNotConnected();

var subscribePacket = new MqttSubscribePacket
{
PacketIdentifier = GetNewPacketIdentifier(),
TopicFilters = topicFilters
TopicFilters = topicFilters.ToList()
};

var response = await SendAndReceiveAsync<MqttSubAckPacket>(subscribePacket).ConfigureAwait(false);

if (response.SubscribeReturnCodes.Count != topicFilters.Count)
if (response.SubscribeReturnCodes.Count != subscribePacket.TopicFilters.Count)
{
throw new MqttProtocolViolationException("The return codes are not matching the topic filters [MQTT-3.9.3-1].");
}

return topicFilters.Select((t, i) => new MqttSubscribeResult(t, response.SubscribeReturnCodes[i])).ToList();
return subscribePacket.TopicFilters.Select((t, i) => new MqttSubscribeResult(t, response.SubscribeReturnCodes[i])).ToList();
}

public Task Unsubscribe(params string[] topicFilters)
public async Task UnsubscribeAsync(IEnumerable<string> topicFilters)
{
if (topicFilters == null) throw new ArgumentNullException(nameof(topicFilters));

return Unsubscribe(topicFilters.ToList());
}

public Task Unsubscribe(IList<string> topicFilters)
{
if (topicFilters == null) throw new ArgumentNullException(nameof(topicFilters));
if (!topicFilters.Any()) throw new MqttProtocolViolationException("At least one topic filter must be set [MQTT-3.10.3-2].");
ThrowIfNotConnected();

var unsubscribePacket = new MqttUnsubscribePacket
{
PacketIdentifier = GetNewPacketIdentifier(),
TopicFilters = topicFilters
TopicFilters = topicFilters.ToList()
};

return SendAndReceiveAsync<MqttUnsubAckPacket>(unsubscribePacket);
await SendAndReceiveAsync<MqttUnsubAckPacket>(unsubscribePacket);
}

public async Task PublishAsync(IEnumerable<MqttApplicationMessage> applicationMessages)
@@ -178,9 +149,11 @@ namespace MQTTnet.Core.Client
switch (qosGroup.Key)
{
case MqttQualityOfServiceLevel.AtMostOnce:
// No packet identifier is used for QoS 0 [3.3.2.2 Packet Identifier]
await _adapter.SendPacketsAsync(_options.DefaultCommunicationTimeout, qosPackets);
break;
{
// No packet identifier is used for QoS 0 [3.3.2.2 Packet Identifier]
await _adapter.SendPacketsAsync(_options.DefaultCommunicationTimeout, _cancellationTokenSource.Token, qosPackets);
break;
}
case MqttQualityOfServiceLevel.AtLeastOnce:
{
foreach (var publishPacket in qosPackets)
@@ -188,6 +161,7 @@ namespace MQTTnet.Core.Client
publishPacket.PacketIdentifier = GetNewPacketIdentifier();
await SendAndReceiveAsync<MqttPubAckPacket>(publishPacket);
}

break;
}
case MqttQualityOfServiceLevel.ExactlyOnce:
@@ -195,95 +169,100 @@ namespace MQTTnet.Core.Client
foreach (var publishPacket in qosPackets)
{
publishPacket.PacketIdentifier = GetNewPacketIdentifier();
await PublishExactlyOncePacketAsync(publishPacket);
var pubRecPacket = await SendAndReceiveAsync<MqttPubRecPacket>(publishPacket).ConfigureAwait(false);
await SendAndReceiveAsync<MqttPubCompPacket>(pubRecPacket.CreateResponse<MqttPubRelPacket>()).ConfigureAwait(false);
}

break;
}
default:
throw new InvalidOperationException();
{
throw new InvalidOperationException();
}
}
}
}

private async Task PublishExactlyOncePacketAsync(MqttBasePacket publishPacket)
private void ThrowIfNotConnected()
{
var pubRecPacket = await SendAndReceiveAsync<MqttPubRecPacket>(publishPacket).ConfigureAwait(false);
await SendAndReceiveAsync<MqttPubCompPacket>(pubRecPacket.CreateResponse<MqttPubRelPacket>()).ConfigureAwait(false);
if (!IsConnected) throw new MqttCommunicationException("The client is not connected.");
}

private void ThrowIfNotConnected()
private void ThrowIfConnected(string message)
{
if (!IsConnected) throw new MqttCommunicationException("The client is not connected.");
if (IsConnected) throw new MqttProtocolViolationException(message);
}

private async Task DisconnectInternalAsync()
{
var cts = _cancellationTokenSource;
if (cts == null || cts.IsCancellationRequested)
{
return;
}

cts.Cancel(false);
cts.Dispose();
_cancellationTokenSource = null;

try
{
await _adapter.DisconnectAsync(_options.DefaultCommunicationTimeout).ConfigureAwait(false);
MqttTrace.Information(nameof(MqttClient), "Disconnected from adapter.");
}
catch (Exception exception)
{
MqttTrace.Warning(nameof(MqttClient), exception, "Error while disconnecting.");
MqttTrace.Warning(nameof(MqttClient), exception, "Error while disconnecting from adapter.");
}
finally
{
_cancellationTokenSource?.Cancel(false);
_cancellationTokenSource?.Dispose();
_cancellationTokenSource = null;

IsConnected = false;

if (!_disconnectedEventSuspended)
{
_disconnectedEventSuspended = true;
Disconnected?.Invoke(this, EventArgs.Empty);
}
Disconnected?.Invoke(this, EventArgs.Empty);
}
}

private async Task ProcessReceivedPacketAsync(MqttBasePacket mqttPacket)
private async Task ProcessReceivedPacketAsync(MqttBasePacket packet)
{
try
{
if (mqttPacket is MqttPingReqPacket)
MqttTrace.Information(nameof(MqttClient), "Received <<< {0}", packet);

if (packet is MqttPingReqPacket)
{
await SendAsync(new MqttPingRespPacket());
return;
}

if (mqttPacket is MqttDisconnectPacket)
if (packet is MqttDisconnectPacket)
{
await DisconnectAsync();
return;
}

if (mqttPacket is MqttPublishPacket publishPacket)
if (packet is MqttPublishPacket publishPacket)
{
await ProcessReceivedPublishPacket(publishPacket);
return;
}

if (mqttPacket is MqttPubRelPacket pubRelPacket)
if (packet is MqttPubRelPacket pubRelPacket)
{
await ProcessReceivedPubRelPacket(pubRelPacket);
return;
}

_packetDispatcher.Dispatch(mqttPacket);
_packetDispatcher.Dispatch(packet);
}
catch (Exception exception)
{
MqttTrace.Error(nameof(MqttClient), exception, "Error while processing received packet.");
MqttTrace.Error(nameof(MqttClient), exception, "Unhandled exception while processing received packet.");
}
}

private void FireApplicationMessageReceivedEvent(MqttPublishPacket publishPacket)
{
var applicationMessage = publishPacket.ToApplicationMessage();

try
{
var applicationMessage = publishPacket.ToApplicationMessage();
ApplicationMessageReceived?.Invoke(this, new MqttApplicationMessageReceivedEventArgs(applicationMessage));
}
catch (Exception exception)
@@ -335,13 +314,13 @@ namespace MQTTnet.Core.Client

private Task SendAsync(MqttBasePacket packet)
{
return _adapter.SendPacketsAsync(_options.DefaultCommunicationTimeout, packet);
return _adapter.SendPacketsAsync(_options.DefaultCommunicationTimeout, _cancellationTokenSource.Token, packet);
}

private async Task<TResponsePacket> SendAndReceiveAsync<TResponsePacket>(MqttBasePacket requestPacket) where TResponsePacket : MqttBasePacket
{
var wait = _packetDispatcher.WaitForPacketAsync(requestPacket, typeof(TResponsePacket), _options.DefaultCommunicationTimeout);
await _adapter.SendPacketsAsync(_options.DefaultCommunicationTimeout, requestPacket).ConfigureAwait(false);
await _adapter.SendPacketsAsync(_options.DefaultCommunicationTimeout, _cancellationTokenSource.Token, requestPacket).ConfigureAwait(false);
return (TResponsePacket)await wait.ConfigureAwait(false);
}

@@ -359,17 +338,25 @@ namespace MQTTnet.Core.Client
while (!cancellationToken.IsCancellationRequested)
{
await Task.Delay(_options.KeepAlivePeriod, cancellationToken).ConfigureAwait(false);
if (cancellationToken.IsCancellationRequested)
{
return;
}

await SendAndReceiveAsync<MqttPingRespPacket>(new MqttPingReqPacket()).ConfigureAwait(false);
}
}
catch (TaskCanceledException)
{
}
catch (MqttCommunicationException exception)
{
MqttTrace.Warning(nameof(MqttClient), exception, "MQTT communication error while receiving packets.");
MqttTrace.Warning(nameof(MqttClient), exception, "MQTT communication exception while sending/receiving keep alive packets.");
await DisconnectInternalAsync().ConfigureAwait(false);
}
catch (Exception exception)
{
MqttTrace.Warning(nameof(MqttClient), exception, "Error while sending/receiving keep alive packets.");
MqttTrace.Warning(nameof(MqttClient), exception, "Unhandled exception while sending/receiving keep alive packets.");
await DisconnectInternalAsync().ConfigureAwait(false);
}
finally
@@ -381,16 +368,23 @@ namespace MQTTnet.Core.Client
private async Task ReceivePackets(CancellationToken cancellationToken)
{
MqttTrace.Information(nameof(MqttClient), "Start receiving packets.");

try
{
while (!cancellationToken.IsCancellationRequested)
{
var packet = await _adapter.ReceivePacketAsync(TimeSpan.Zero).ConfigureAwait(false);
MqttTrace.Information(nameof(MqttClient), "Received <<< {0}", packet);
var packet = await _adapter.ReceivePacketAsync(TimeSpan.Zero, cancellationToken).ConfigureAwait(false);
if (cancellationToken.IsCancellationRequested)
{
return;
}

StartProcessReceivedPacket(packet, cancellationToken);
}
}
catch (TaskCanceledException)
{
}
catch (MqttCommunicationException exception)
{
MqttTrace.Warning(nameof(MqttClient), exception, "MQTT communication exception while receiving packets.");
@@ -410,21 +404,21 @@ namespace MQTTnet.Core.Client
private void StartProcessReceivedPacket(MqttBasePacket packet, CancellationToken cancellationToken)
{
#pragma warning disable CS4014 // Because this call is not awaited, execution of the current method continues before the call is completed
Task.Run(() => ProcessReceivedPacketAsync(packet), cancellationToken);
Task.Run(async () => await ProcessReceivedPacketAsync(packet), cancellationToken).ConfigureAwait(false);
#pragma warning restore CS4014 // Because this call is not awaited, execution of the current method continues before the call is completed
}

private void StartReceivePackets()
private void StartReceivePackets(CancellationToken cancellationToken)
{
#pragma warning disable CS4014 // Because this call is not awaited, execution of the current method continues before the call is completed
Task.Run(() => ReceivePackets(_cancellationTokenSource.Token), _cancellationTokenSource.Token);
Task.Run(async () => await ReceivePackets(cancellationToken), cancellationToken).ConfigureAwait(false); ;
#pragma warning restore CS4014 // Because this call is not awaited, execution of the current method continues before the call is completed
}

private void StartSendKeepAliveMessages()
private void StartSendKeepAliveMessages(CancellationToken cancellationToken)
{
#pragma warning disable CS4014 // Because this call is not awaited, execution of the current method continues before the call is completed
Task.Run(() => SendKeepAliveMessagesAsync(_cancellationTokenSource.Token), _cancellationTokenSource.Token);
Task.Run(async () => await SendKeepAliveMessagesAsync(cancellationToken), cancellationToken).ConfigureAwait(false);
#pragma warning restore CS4014 // Because this call is not awaited, execution of the current method continues before the call is completed
}
}

+ 24
- 1
MQTTnet.Core/Client/MqttClientExtensions.cs View File

@@ -1,4 +1,8 @@
using System.Threading.Tasks;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
using MQTTnet.Core.Packets;

namespace MQTTnet.Core.Client
{
@@ -6,7 +10,26 @@ namespace MQTTnet.Core.Client
{
public static Task PublishAsync(this IMqttClient client, params MqttApplicationMessage[] applicationMessages)
{
if (client == null) throw new ArgumentNullException(nameof(client));
if (applicationMessages == null) throw new ArgumentNullException(nameof(applicationMessages));

return client.PublishAsync(applicationMessages);
}

public static Task<IList<MqttSubscribeResult>> SubscribeAsync(this IMqttClient client, params TopicFilter[] topicFilters)
{
if (client == null) throw new ArgumentNullException(nameof(client));
if (topicFilters == null) throw new ArgumentNullException(nameof(topicFilters));

return client.SubscribeAsync(topicFilters.ToList());
}

public static Task UnsubscribeAsync(this IMqttClient client, params string[] topicFilters)
{
if (client == null) throw new ArgumentNullException(nameof(client));
if (topicFilters == null) throw new ArgumentNullException(nameof(topicFilters));

return client.UnsubscribeAsync(topicFilters.ToList());
}
}
}

+ 1
- 1
MQTTnet.Core/Server/MqttClientMessageQueue.cs View File

@@ -63,7 +63,7 @@ namespace MQTTnet.Core.Server
var packets = consumable.Take(_pendingPublishPackets.Count).ToList();
try
{
await adapter.SendPacketsAsync(_options.DefaultCommunicationTimeout, packets).ConfigureAwait(false);
await adapter.SendPacketsAsync(_options.DefaultCommunicationTimeout, cancellationToken, packets).ConfigureAwait(false);
}
catch (MqttCommunicationException exception)
{


+ 8
- 8
MQTTnet.Core/Server/MqttClientSession.cs View File

@@ -54,7 +54,7 @@ namespace MQTTnet.Core.Server
_messageQueue.Start(adapter);
while (!_cancellationTokenSource.IsCancellationRequested)
{
var packet = await adapter.ReceivePacketAsync(TimeSpan.Zero).ConfigureAwait(false);
var packet = await adapter.ReceivePacketAsync(TimeSpan.Zero, _cancellationTokenSource.Token).ConfigureAwait(false);
await HandleIncomingPacketAsync(packet).ConfigureAwait(false);
}
}
@@ -103,12 +103,12 @@ namespace MQTTnet.Core.Server
{
if (packet is MqttSubscribePacket subscribePacket)
{
return Adapter.SendPacketsAsync(_options.DefaultCommunicationTimeout, _subscriptionsManager.Subscribe(subscribePacket));
return Adapter.SendPacketsAsync(_options.DefaultCommunicationTimeout, _cancellationTokenSource.Token, _subscriptionsManager.Subscribe(subscribePacket));
}

if (packet is MqttUnsubscribePacket unsubscribePacket)
{
return Adapter.SendPacketsAsync(_options.DefaultCommunicationTimeout, _subscriptionsManager.Unsubscribe(unsubscribePacket));
return Adapter.SendPacketsAsync(_options.DefaultCommunicationTimeout, _cancellationTokenSource.Token, _subscriptionsManager.Unsubscribe(unsubscribePacket));
}

if (packet is MqttPublishPacket publishPacket)
@@ -123,7 +123,7 @@ namespace MQTTnet.Core.Server

if (packet is MqttPubRecPacket pubRecPacket)
{
return Adapter.SendPacketsAsync(_options.DefaultCommunicationTimeout, pubRecPacket.CreateResponse<MqttPubRelPacket>());
return Adapter.SendPacketsAsync(_options.DefaultCommunicationTimeout, _cancellationTokenSource.Token, pubRecPacket.CreateResponse<MqttPubRelPacket>());
}

if (packet is MqttPubAckPacket || packet is MqttPubCompPacket)
@@ -134,7 +134,7 @@ namespace MQTTnet.Core.Server

if (packet is MqttPingReqPacket)
{
return Adapter.SendPacketsAsync(_options.DefaultCommunicationTimeout, new MqttPingRespPacket());
return Adapter.SendPacketsAsync(_options.DefaultCommunicationTimeout, _cancellationTokenSource.Token, new MqttPingRespPacket());
}

if (packet is MqttDisconnectPacket || packet is MqttConnectPacket)
@@ -160,7 +160,7 @@ namespace MQTTnet.Core.Server
if (publishPacket.QualityOfServiceLevel == MqttQualityOfServiceLevel.AtLeastOnce)
{
_publishPacketReceivedCallback(this, publishPacket);
return Adapter.SendPacketsAsync(_options.DefaultCommunicationTimeout, new MqttPubAckPacket { PacketIdentifier = publishPacket.PacketIdentifier });
return Adapter.SendPacketsAsync(_options.DefaultCommunicationTimeout, _cancellationTokenSource.Token, new MqttPubAckPacket { PacketIdentifier = publishPacket.PacketIdentifier });
}

if (publishPacket.QualityOfServiceLevel == MqttQualityOfServiceLevel.ExactlyOnce)
@@ -173,7 +173,7 @@ namespace MQTTnet.Core.Server

_publishPacketReceivedCallback(this, publishPacket);

return Adapter.SendPacketsAsync(_options.DefaultCommunicationTimeout, new MqttPubRecPacket { PacketIdentifier = publishPacket.PacketIdentifier });
return Adapter.SendPacketsAsync(_options.DefaultCommunicationTimeout, _cancellationTokenSource.Token, new MqttPubRecPacket { PacketIdentifier = publishPacket.PacketIdentifier });
}

throw new MqttCommunicationException("Received a not supported QoS level.");
@@ -186,7 +186,7 @@ namespace MQTTnet.Core.Server
_unacknowledgedPublishPackets.Remove(pubRelPacket.PacketIdentifier);
}

return Adapter.SendPacketsAsync(_options.DefaultCommunicationTimeout, new MqttPubCompPacket { PacketIdentifier = pubRelPacket.PacketIdentifier });
return Adapter.SendPacketsAsync(_options.DefaultCommunicationTimeout, _cancellationTokenSource.Token, new MqttPubCompPacket { PacketIdentifier = pubRelPacket.PacketIdentifier });
}
}
}

+ 4
- 4
MQTTnet.Core/Server/MqttClientSessionsManager.cs View File

@@ -1,6 +1,7 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using MQTTnet.Core.Adapter;
using MQTTnet.Core.Diagnostics;
@@ -28,8 +29,7 @@ namespace MQTTnet.Core.Server
{
try
{
var connectPacket = await eventArgs.ClientAdapter.ReceivePacketAsync(_options.DefaultCommunicationTimeout).ConfigureAwait(false) as MqttConnectPacket;
if (connectPacket == null)
if (!(await eventArgs.ClientAdapter.ReceivePacketAsync(_options.DefaultCommunicationTimeout, CancellationToken.None).ConfigureAwait(false) is MqttConnectPacket connectPacket))
{
throw new MqttProtocolViolationException("The first packet from a client must be a 'CONNECT' packet [MQTT-3.1.0-1].");
}
@@ -40,7 +40,7 @@ namespace MQTTnet.Core.Server
var connectReturnCode = ValidateConnection(connectPacket);
if (connectReturnCode != MqttConnectReturnCode.ConnectionAccepted)
{
await eventArgs.ClientAdapter.SendPacketsAsync(_options.DefaultCommunicationTimeout, new MqttConnAckPacket
await eventArgs.ClientAdapter.SendPacketsAsync(_options.DefaultCommunicationTimeout, CancellationToken.None, new MqttConnAckPacket
{
ConnectReturnCode = connectReturnCode
}).ConfigureAwait(false);
@@ -50,7 +50,7 @@ namespace MQTTnet.Core.Server

var clientSession = GetOrCreateClientSession(connectPacket);

await eventArgs.ClientAdapter.SendPacketsAsync(_options.DefaultCommunicationTimeout, new MqttConnAckPacket
await eventArgs.ClientAdapter.SendPacketsAsync(_options.DefaultCommunicationTimeout, CancellationToken.None, new MqttConnAckPacket
{
ConnectReturnCode = connectReturnCode,
IsSessionPresent = clientSession.IsExistingSession


+ 41
- 3
Tests/MQTTnet.Core.Tests/ExtensionTests.cs View File

@@ -11,23 +11,61 @@ namespace MQTTnet.Core.Tests
{
[ExpectedException(typeof(MqttCommunicationTimedOutException))]
[TestMethod]
public async Task TestTimeoutAfter()
public async Task TimeoutAfter()
{
await Task.Delay(TimeSpan.FromMilliseconds(500)).TimeoutAfter(TimeSpan.FromMilliseconds(100));
}

[ExpectedException(typeof(MqttCommunicationTimedOutException))]
[TestMethod]
public async Task TestTimeoutAfterWithResult()
public async Task TimeoutAfterWithResult()
{
await Task.Delay(TimeSpan.FromMilliseconds(500)).ContinueWith(t => 5).TimeoutAfter(TimeSpan.FromMilliseconds(100));
}

[TestMethod]
public async Task TestTimeoutAfterCompleteInTime()
public async Task TimeoutAfterCompleteInTime()
{
var result = await Task.Delay(TimeSpan.FromMilliseconds(100)).ContinueWith(t => 5).TimeoutAfter(TimeSpan.FromMilliseconds(500));
Assert.AreEqual(5, result);
}

[TestMethod]
public async Task TimeoutAfterWithInnerException()
{
try
{
await Task.Run(() =>
{
var iis = new int[0];
iis[1] = 0;
}).TimeoutAfter(TimeSpan.FromSeconds(1));

Assert.Fail();
}
catch (MqttCommunicationException e)
{
Assert.IsTrue(e.InnerException is IndexOutOfRangeException);
}
}

[TestMethod]
public async Task TimeoutAfterWithInnerExceptionWithResult()
{
try
{
var r = await Task.Run(() =>
{
var iis = new int[0];
return iis[1];
}).TimeoutAfter(TimeSpan.FromSeconds(1));

Assert.Fail();
}
catch (MqttCommunicationException e)
{
Assert.IsTrue(e.InnerException is IndexOutOfRangeException);
}
}
}
}

+ 2
- 1
Tests/MQTTnet.Core.Tests/MqttPacketSerializerTests.cs View File

@@ -1,6 +1,7 @@
using System;
using System.IO;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.VisualStudio.TestTools.UnitTesting;
using MQTTnet.Core.Adapter;
@@ -442,7 +443,7 @@ namespace MQTTnet.Core.Tests

using (var headerStream = new MemoryStream(buffer1))
{
var header = MqttPacketReader.ReadHeaderFromSource(headerStream);
var header = MqttPacketReader.ReadHeaderFromSource(headerStream, CancellationToken.None);

using (var bodyStream = new MemoryStream(buffer1, (int)headerStream.Position, header.BodyLength))
{


+ 2
- 2
Tests/MQTTnet.Core.Tests/MqttServerTests.cs View File

@@ -91,7 +91,7 @@ namespace MQTTnet.Core.Tests
await Task.Delay(500);
Assert.AreEqual(1, receivedMessagesCount);

await c1.Unsubscribe("a");
await c1.UnsubscribeAsync("a");
await c2.PublishAsync(message);

await Task.Delay(500);
@@ -158,7 +158,7 @@ namespace MQTTnet.Core.Tests
await c2.PublishAsync(new MqttApplicationMessage(topic, new byte[0], qualityOfServiceLevel, false));

await Task.Delay(500);
await c1.Unsubscribe(topicFilter);
await c1.UnsubscribeAsync(topicFilter);

await Task.Delay(500);



+ 3
- 8
Tests/MQTTnet.Core.Tests/TestMqttCommunicationAdapter.cs View File

@@ -28,7 +28,7 @@ namespace MQTTnet.Core.Tests
return Task.FromResult(0);
}

public Task SendPacketsAsync(TimeSpan timeout, IEnumerable<MqttBasePacket> packets)
public Task SendPacketsAsync(TimeSpan timeout, CancellationToken cancellationToken, IEnumerable<MqttBasePacket> packets)
{
ThrowIfPartnerIsNull();

@@ -40,16 +40,11 @@ namespace MQTTnet.Core.Tests
return Task.FromResult(0);
}

public Task<MqttBasePacket> ReceivePacketAsync(TimeSpan timeout)
public Task<MqttBasePacket> ReceivePacketAsync(TimeSpan timeout, CancellationToken cancellationToken)
{
ThrowIfPartnerIsNull();

return Task.Run(() => _incomingPackets.Take());
}

public IEnumerable<MqttBasePacket> ReceivePackets(CancellationToken cancellationToken)
{
return _incomingPackets.GetConsumingEnumerable();
return Task.Run(() => _incomingPackets.Take(), cancellationToken);
}

private void SendPacketInternal(MqttBasePacket packet)


Loading…
Cancel
Save