@@ -15,7 +15,7 @@ namespace MQTTnet.Adapter | |||||
Task DisconnectAsync(TimeSpan timeout); | Task DisconnectAsync(TimeSpan timeout); | ||||
Task SendPacketsAsync(TimeSpan timeout, CancellationToken cancellationToken, IEnumerable<MqttBasePacket> packets); | |||||
Task SendPacketsAsync(TimeSpan timeout, CancellationToken cancellationToken, MqttBasePacket[] packets); | |||||
Task<MqttBasePacket> ReceivePacketAsync(TimeSpan timeout, CancellationToken cancellationToken); | Task<MqttBasePacket> ReceivePacketAsync(TimeSpan timeout, CancellationToken cancellationToken); | ||||
} | } | ||||
@@ -1,7 +1,9 @@ | |||||
using System; | using System; | ||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using System.IO; | using System.IO; | ||||
using System.Linq; | |||||
using System.Net.Sockets; | using System.Net.Sockets; | ||||
using System.Runtime.ExceptionServices; | |||||
using System.Runtime.InteropServices; | using System.Runtime.InteropServices; | ||||
using System.Threading; | using System.Threading; | ||||
using System.Threading.Tasks; | using System.Threading.Tasks; | ||||
@@ -49,59 +51,45 @@ namespace MQTTnet.Adapter | |||||
return ExecuteAndWrapExceptionAsync(() => _channel.DisconnectAsync().TimeoutAfter(timeout)); | return ExecuteAndWrapExceptionAsync(() => _channel.DisconnectAsync().TimeoutAfter(timeout)); | ||||
} | } | ||||
public Task SendPacketsAsync(TimeSpan timeout, CancellationToken cancellationToken, IEnumerable<MqttBasePacket> packets) | |||||
public async Task SendPacketsAsync(TimeSpan timeout, CancellationToken cancellationToken, MqttBasePacket[] packets) | |||||
{ | { | ||||
ThrowIfDisposed(); | |||||
return ExecuteAndWrapExceptionAsync(async () => | |||||
for(var i=0;i<packets.Length;i++) | |||||
{ | { | ||||
await _semaphore.WaitAsync(cancellationToken).ConfigureAwait(false); | |||||
try | |||||
{ | |||||
foreach (var packet in packets) | |||||
{ | |||||
if (cancellationToken.IsCancellationRequested) | |||||
{ | |||||
return; | |||||
} | |||||
await SendPacketsAsync(timeout, cancellationToken, packets[i]).ConfigureAwait(false); | |||||
} | |||||
if (packet == null) | |||||
{ | |||||
continue; | |||||
} | |||||
} | |||||
_logger.Verbose<MqttChannelAdapter>("TX >>> {0} [Timeout={1}]", packet, timeout); | |||||
public Task SendPacketsAsync(TimeSpan timeout, CancellationToken cancellationToken, MqttBasePacket packet) | |||||
{ | |||||
ThrowIfDisposed(); | |||||
var chunks = PacketSerializer.Serialize(packet); | |||||
foreach (var chunk in chunks) | |||||
{ | |||||
if (cancellationToken.IsCancellationRequested) | |||||
{ | |||||
return; | |||||
} | |||||
if (packet == null) | |||||
{ | |||||
return Task.FromResult(0); | |||||
} | |||||
await _channel.SendStream.WriteAsync(chunk.Array, chunk.Offset, chunk.Count, cancellationToken).ConfigureAwait(false); | |||||
} | |||||
} | |||||
return ExecuteAndWrapExceptionAsync(async () => | |||||
{ | |||||
if (cancellationToken.IsCancellationRequested) | |||||
{ | |||||
return; | |||||
} | |||||
if (cancellationToken.IsCancellationRequested) | |||||
{ | |||||
return; | |||||
} | |||||
_logger.Verbose<MqttChannelAdapter>("TX >>> {0} [Timeout={1}]", packet, timeout); | |||||
if (timeout > TimeSpan.Zero) | |||||
{ | |||||
await _channel.SendStream.FlushAsync(cancellationToken).TimeoutAfter(timeout).ConfigureAwait(false); | |||||
} | |||||
else | |||||
{ | |||||
await _channel.SendStream.FlushAsync(cancellationToken).ConfigureAwait(false); | |||||
} | |||||
} | |||||
finally | |||||
var packetData = PacketSerializer.Serialize(packet); | |||||
if (cancellationToken.IsCancellationRequested) | |||||
{ | { | ||||
_semaphore.Release(); | |||||
return; | |||||
} | } | ||||
await _channel.SendStream.WriteAsync( | |||||
packetData.Array, | |||||
packetData.Offset, | |||||
(int)packetData.Count, | |||||
cancellationToken).ConfigureAwait(false); | |||||
}); | }); | ||||
} | } | ||||
@@ -121,7 +109,23 @@ namespace MQTTnet.Adapter | |||||
var timeoutCts = new CancellationTokenSource(timeout); | var timeoutCts = new CancellationTokenSource(timeout); | ||||
var linkedCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken, timeoutCts.Token); | var linkedCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken, timeoutCts.Token); | ||||
receivedMqttPacket = await ReceiveAsync(_channel.ReceiveStream, linkedCts.Token).ConfigureAwait(false); | |||||
try | |||||
{ | |||||
receivedMqttPacket = await ReceiveAsync(_channel.ReceiveStream, linkedCts.Token).ConfigureAwait(false); | |||||
} | |||||
catch(OperationCanceledException ex) | |||||
{ | |||||
//check if timed out | |||||
if(linkedCts.IsCancellationRequested && !cancellationToken.IsCancellationRequested) | |||||
{ | |||||
//only timeout token was cancelled | |||||
throw new MqttCommunicationTimedOutException(ex); | |||||
} | |||||
else | |||||
{ | |||||
throw; | |||||
} | |||||
} | |||||
} | } | ||||
else | else | ||||
{ | { | ||||
@@ -1,6 +1,11 @@ | |||||
namespace MQTTnet.Exceptions | |||||
using System; | |||||
namespace MQTTnet.Exceptions | |||||
{ | { | ||||
public sealed class MqttCommunicationTimedOutException : MqttCommunicationException | public sealed class MqttCommunicationTimedOutException : MqttCommunicationException | ||||
{ | { | ||||
public MqttCommunicationTimedOutException() { } | |||||
public MqttCommunicationTimedOutException(Exception innerException) : base(innerException) { } | |||||
} | } | ||||
} | } |
@@ -60,6 +60,7 @@ namespace MQTTnet.Implementations | |||||
if (_socket == null) | if (_socket == null) | ||||
{ | { | ||||
_socket = new Socket(SocketType.Stream, ProtocolType.Tcp); | _socket = new Socket(SocketType.Stream, ProtocolType.Tcp); | ||||
} | } | ||||
#if NET452 || NET461 | #if NET452 || NET461 | ||||
@@ -68,6 +69,8 @@ namespace MQTTnet.Implementations | |||||
await _socket.ConnectAsync(_options.Server, _options.GetPort()).ConfigureAwait(false); | await _socket.ConnectAsync(_options.Server, _options.GetPort()).ConfigureAwait(false); | ||||
#endif | #endif | ||||
_socket.NoDelay = true; | |||||
if (_options.TlsOptions.UseTls) | if (_options.TlsOptions.UseTls) | ||||
{ | { | ||||
_sslStream = new SslStream(new NetworkStream(_socket, true), false, InternalUserCertificateValidationCallback); | _sslStream = new SslStream(new NetworkStream(_socket, true), false, InternalUserCertificateValidationCallback); | ||||
@@ -39,6 +39,8 @@ namespace MQTTnet.Implementations | |||||
if (options.DefaultEndpointOptions.IsEnabled) | if (options.DefaultEndpointOptions.IsEnabled) | ||||
{ | { | ||||
_defaultEndpointSocket = new Socket(SocketType.Stream, ProtocolType.Tcp); | _defaultEndpointSocket = new Socket(SocketType.Stream, ProtocolType.Tcp); | ||||
_defaultEndpointSocket.NoDelay = true; | |||||
_defaultEndpointSocket.Bind(new IPEndPoint(options.DefaultEndpointOptions.BoundIPAddress, options.GetDefaultEndpointPort())); | _defaultEndpointSocket.Bind(new IPEndPoint(options.DefaultEndpointOptions.BoundIPAddress, options.GetDefaultEndpointPort())); | ||||
_defaultEndpointSocket.Listen(options.ConnectionBacklog); | _defaultEndpointSocket.Listen(options.ConnectionBacklog); | ||||
@@ -102,7 +104,7 @@ namespace MQTTnet.Implementations | |||||
#else | #else | ||||
var clientSocket = await _defaultEndpointSocket.AcceptAsync().ConfigureAwait(false); | var clientSocket = await _defaultEndpointSocket.AcceptAsync().ConfigureAwait(false); | ||||
#endif | #endif | ||||
clientSocket.NoDelay=true; | |||||
var clientAdapter = new MqttChannelAdapter(new MqttTcpChannel(clientSocket, null), new MqttPacketSerializer(), _logger); | var clientAdapter = new MqttChannelAdapter(new MqttTcpChannel(clientSocket, null), new MqttPacketSerializer(), _logger); | ||||
ClientAccepted?.Invoke(this, new MqttServerAdapterClientAcceptedEventArgs(clientAdapter)); | ClientAccepted?.Invoke(this, new MqttServerAdapterClientAcceptedEventArgs(clientAdapter)); | ||||
} | } | ||||
@@ -9,7 +9,7 @@ namespace MQTTnet.Serializer | |||||
{ | { | ||||
MqttProtocolVersion ProtocolVersion { get; set; } | MqttProtocolVersion ProtocolVersion { get; set; } | ||||
ICollection<ArraySegment<byte>> Serialize(MqttBasePacket mqttPacket); | |||||
ArraySegment<byte> Serialize(MqttBasePacket mqttPacket); | |||||
MqttBasePacket Deserialize(MqttPacketHeader header, MemoryStream body); | MqttBasePacket Deserialize(MqttPacketHeader header, MemoryStream body); | ||||
} | } |
@@ -16,29 +16,36 @@ namespace MQTTnet.Serializer | |||||
public MqttProtocolVersion ProtocolVersion { get; set; } = MqttProtocolVersion.V311; | public MqttProtocolVersion ProtocolVersion { get; set; } = MqttProtocolVersion.V311; | ||||
public ICollection<ArraySegment<byte>> Serialize(MqttBasePacket packet) | |||||
public ArraySegment<byte> Serialize(MqttBasePacket packet) | |||||
{ | { | ||||
if (packet == null) throw new ArgumentNullException(nameof(packet)); | if (packet == null) throw new ArgumentNullException(nameof(packet)); | ||||
using (var stream = new MemoryStream(128)) | using (var stream = new MemoryStream(128)) | ||||
using (var writer = new MqttPacketWriter(stream)) | using (var writer = new MqttPacketWriter(stream)) | ||||
{ | { | ||||
//leave enough head space for max header (fixed + 4 variable remaining lenght) | |||||
stream.Position = 5; | |||||
var fixedHeader = SerializePacket(packet, writer); | var fixedHeader = SerializePacket(packet, writer); | ||||
var remainingLength = (int)stream.Length; | |||||
writer.Write(fixedHeader); | |||||
MqttPacketWriter.WriteRemainingLength(remainingLength, writer); | |||||
var headerLength = (int)stream.Length - remainingLength; | |||||
var remainingLength = MqttPacketWriter.GetRemainingLength((int)stream.Length-5); | |||||
var headerSize = remainingLength.Length + 1; | |||||
var headerOffset = 5 - headerSize; | |||||
//position curson on correct offset on beginining of array | |||||
stream.Position = headerOffset; | |||||
//write header | |||||
writer.Write(fixedHeader); | |||||
writer.Write(remainingLength,0,remainingLength.Length); | |||||
#if NET461 || NET452 || NETSTANDARD2_0 | #if NET461 || NET452 || NETSTANDARD2_0 | ||||
var buffer = stream.GetBuffer(); | var buffer = stream.GetBuffer(); | ||||
#else | #else | ||||
var buffer = stream.ToArray(); | var buffer = stream.ToArray(); | ||||
#endif | #endif | ||||
return new List<ArraySegment<byte>> | |||||
{ | |||||
new ArraySegment<byte>(buffer, remainingLength, headerLength), | |||||
new ArraySegment<byte>(buffer, 0, remainingLength) | |||||
}; | |||||
return new ArraySegment<byte>(buffer, headerOffset, (int)stream.Length- headerOffset); | |||||
} | } | ||||
} | } | ||||
@@ -1,5 +1,6 @@ | |||||
using System; | using System; | ||||
using System.IO; | using System.IO; | ||||
using System.Linq; | |||||
using System.Text; | using System.Text; | ||||
using MQTTnet.Protocol; | using MQTTnet.Protocol; | ||||
@@ -55,14 +56,16 @@ namespace MQTTnet.Serializer | |||||
Write(value); | Write(value); | ||||
} | } | ||||
public static void WriteRemainingLength(int length, BinaryWriter target) | |||||
public static byte[] GetRemainingLength(int length) | |||||
{ | { | ||||
if (length == 0) | |||||
if (length <= 0) | |||||
{ | { | ||||
target.Write((byte)0); | |||||
return; | |||||
return new [] { (byte)0 }; | |||||
} | } | ||||
var bytes = new byte[4]; | |||||
int arraySize = 0; | |||||
// Alorithm taken from http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html. | // Alorithm taken from http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html. | ||||
var x = length; | var x = length; | ||||
do | do | ||||
@@ -74,8 +77,12 @@ namespace MQTTnet.Serializer | |||||
encodedByte = encodedByte | 128; | encodedByte = encodedByte | 128; | ||||
} | } | ||||
target.Write((byte)encodedByte); | |||||
bytes[arraySize] = (byte)encodedByte; | |||||
arraySize++; | |||||
} while (x > 0); | } while (x > 0); | ||||
return bytes.Take(arraySize).ToArray(); | |||||
} | } | ||||
} | } | ||||
} | } |
@@ -403,9 +403,9 @@ namespace MQTTnet.Core.Tests | |||||
private static void SerializeAndCompare(MqttBasePacket packet, string expectedBase64Value, MqttProtocolVersion protocolVersion = MqttProtocolVersion.V311) | private static void SerializeAndCompare(MqttBasePacket packet, string expectedBase64Value, MqttProtocolVersion protocolVersion = MqttProtocolVersion.V311) | ||||
{ | { | ||||
var serializer = new MqttPacketSerializer { ProtocolVersion = protocolVersion }; | var serializer = new MqttPacketSerializer { ProtocolVersion = protocolVersion }; | ||||
var chunks = serializer.Serialize(packet); | |||||
var data = serializer.Serialize(packet); | |||||
Assert.AreEqual(expectedBase64Value, Convert.ToBase64String(Join(chunks))); | |||||
Assert.AreEqual(expectedBase64Value, Convert.ToBase64String(Join(data))); | |||||
} | } | ||||
private static void DeserializeAndCompare(MqttBasePacket packet, string expectedBase64Value, MqttProtocolVersion protocolVersion = MqttProtocolVersion.V311) | private static void DeserializeAndCompare(MqttBasePacket packet, string expectedBase64Value, MqttProtocolVersion protocolVersion = MqttProtocolVersion.V311) | ||||
@@ -438,5 +438,16 @@ namespace MQTTnet.Core.Tests | |||||
return buffer.ToArray(); | return buffer.ToArray(); | ||||
} | } | ||||
private static byte[] Join(params ArraySegment<byte>[] chunks) | |||||
{ | |||||
var buffer = new MemoryStream(); | |||||
foreach (var chunk in chunks) | |||||
{ | |||||
buffer.Write(chunk.Array, chunk.Offset, chunk.Count); | |||||
} | |||||
return buffer.ToArray(); | |||||
} | |||||
} | } | ||||
} | } |
@@ -31,7 +31,7 @@ namespace MQTTnet.Core.Tests | |||||
return Task.FromResult(0); | return Task.FromResult(0); | ||||
} | } | ||||
public Task SendPacketsAsync(TimeSpan timeout, CancellationToken cancellationToken, IEnumerable<MqttBasePacket> packets) | |||||
public Task SendPacketsAsync(TimeSpan timeout, CancellationToken cancellationToken, MqttBasePacket[] packets) | |||||
{ | { | ||||
ThrowIfPartnerIsNull(); | ThrowIfPartnerIsNull(); | ||||