diff --git a/Frameworks/MQTTnet.NetStandard/Adapter/IMqttChannelAdapter.cs b/Frameworks/MQTTnet.NetStandard/Adapter/IMqttChannelAdapter.cs index 7a5c02a..8f03e65 100644 --- a/Frameworks/MQTTnet.NetStandard/Adapter/IMqttChannelAdapter.cs +++ b/Frameworks/MQTTnet.NetStandard/Adapter/IMqttChannelAdapter.cs @@ -15,7 +15,7 @@ namespace MQTTnet.Adapter Task DisconnectAsync(TimeSpan timeout); - Task SendPacketsAsync(TimeSpan timeout, CancellationToken cancellationToken, IEnumerable packets); + Task SendPacketsAsync(TimeSpan timeout, CancellationToken cancellationToken, MqttBasePacket[] packets); Task ReceivePacketAsync(TimeSpan timeout, CancellationToken cancellationToken); } diff --git a/Frameworks/MQTTnet.NetStandard/Adapter/MqttChannelAdapter.cs b/Frameworks/MQTTnet.NetStandard/Adapter/MqttChannelAdapter.cs index 123b88d..7fffa16 100644 --- a/Frameworks/MQTTnet.NetStandard/Adapter/MqttChannelAdapter.cs +++ b/Frameworks/MQTTnet.NetStandard/Adapter/MqttChannelAdapter.cs @@ -1,7 +1,9 @@ using System; using System.Collections.Generic; using System.IO; +using System.Linq; using System.Net.Sockets; +using System.Runtime.ExceptionServices; using System.Runtime.InteropServices; using System.Threading; using System.Threading.Tasks; @@ -49,59 +51,45 @@ namespace MQTTnet.Adapter return ExecuteAndWrapExceptionAsync(() => _channel.DisconnectAsync().TimeoutAfter(timeout)); } - public Task SendPacketsAsync(TimeSpan timeout, CancellationToken cancellationToken, IEnumerable packets) + public async Task SendPacketsAsync(TimeSpan timeout, CancellationToken cancellationToken, MqttBasePacket[] packets) { - ThrowIfDisposed(); - - return ExecuteAndWrapExceptionAsync(async () => + for(var i=0;i("TX >>> {0} [Timeout={1}]", packet, timeout); + public Task SendPacketsAsync(TimeSpan timeout, CancellationToken cancellationToken, MqttBasePacket packet) + { + ThrowIfDisposed(); - var chunks = PacketSerializer.Serialize(packet); - foreach (var chunk in chunks) - { - if (cancellationToken.IsCancellationRequested) - { - return; - } + if (packet == null) + { + return Task.FromResult(0); + } - await _channel.SendStream.WriteAsync(chunk.Array, chunk.Offset, chunk.Count, cancellationToken).ConfigureAwait(false); - } - } + return ExecuteAndWrapExceptionAsync(async () => + { + if (cancellationToken.IsCancellationRequested) + { + return; + } + - if (cancellationToken.IsCancellationRequested) - { - return; - } + _logger.Verbose("TX >>> {0} [Timeout={1}]", packet, timeout); - if (timeout > TimeSpan.Zero) - { - await _channel.SendStream.FlushAsync(cancellationToken).TimeoutAfter(timeout).ConfigureAwait(false); - } - else - { - await _channel.SendStream.FlushAsync(cancellationToken).ConfigureAwait(false); - } - } - finally + var packetData = PacketSerializer.Serialize(packet); + if (cancellationToken.IsCancellationRequested) { - _semaphore.Release(); + return; } + await _channel.SendStream.WriteAsync( + packetData.Array, + packetData.Offset, + (int)packetData.Count, + cancellationToken).ConfigureAwait(false); + }); } @@ -121,7 +109,23 @@ namespace MQTTnet.Adapter var timeoutCts = new CancellationTokenSource(timeout); var linkedCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken, timeoutCts.Token); - receivedMqttPacket = await ReceiveAsync(_channel.ReceiveStream, linkedCts.Token).ConfigureAwait(false); + try + { + receivedMqttPacket = await ReceiveAsync(_channel.ReceiveStream, linkedCts.Token).ConfigureAwait(false); + } + catch(OperationCanceledException ex) + { + //check if timed out + if(linkedCts.IsCancellationRequested && !cancellationToken.IsCancellationRequested) + { + //only timeout token was cancelled + throw new MqttCommunicationTimedOutException(ex); + } + else + { + throw; + } + } } else { diff --git a/Frameworks/MQTTnet.NetStandard/Exceptions/MqttCommunicationTimedOutException.cs b/Frameworks/MQTTnet.NetStandard/Exceptions/MqttCommunicationTimedOutException.cs index 7d0adcd..86f58b3 100644 --- a/Frameworks/MQTTnet.NetStandard/Exceptions/MqttCommunicationTimedOutException.cs +++ b/Frameworks/MQTTnet.NetStandard/Exceptions/MqttCommunicationTimedOutException.cs @@ -1,6 +1,11 @@ -namespace MQTTnet.Exceptions +using System; + +namespace MQTTnet.Exceptions { public sealed class MqttCommunicationTimedOutException : MqttCommunicationException { + public MqttCommunicationTimedOutException() { } + public MqttCommunicationTimedOutException(Exception innerException) : base(innerException) { } + } } diff --git a/Frameworks/MQTTnet.NetStandard/Implementations/MqttTcpChannel.cs b/Frameworks/MQTTnet.NetStandard/Implementations/MqttTcpChannel.cs index 8d43d87..85fe31d 100644 --- a/Frameworks/MQTTnet.NetStandard/Implementations/MqttTcpChannel.cs +++ b/Frameworks/MQTTnet.NetStandard/Implementations/MqttTcpChannel.cs @@ -60,6 +60,7 @@ namespace MQTTnet.Implementations if (_socket == null) { _socket = new Socket(SocketType.Stream, ProtocolType.Tcp); + } #if NET452 || NET461 @@ -68,6 +69,8 @@ namespace MQTTnet.Implementations await _socket.ConnectAsync(_options.Server, _options.GetPort()).ConfigureAwait(false); #endif + _socket.NoDelay = true; + if (_options.TlsOptions.UseTls) { _sslStream = new SslStream(new NetworkStream(_socket, true), false, InternalUserCertificateValidationCallback); diff --git a/Frameworks/MQTTnet.NetStandard/Implementations/MqttTcpServerAdapter.cs b/Frameworks/MQTTnet.NetStandard/Implementations/MqttTcpServerAdapter.cs index dc33a3f..8f896d5 100644 --- a/Frameworks/MQTTnet.NetStandard/Implementations/MqttTcpServerAdapter.cs +++ b/Frameworks/MQTTnet.NetStandard/Implementations/MqttTcpServerAdapter.cs @@ -39,6 +39,8 @@ namespace MQTTnet.Implementations if (options.DefaultEndpointOptions.IsEnabled) { _defaultEndpointSocket = new Socket(SocketType.Stream, ProtocolType.Tcp); + _defaultEndpointSocket.NoDelay = true; + _defaultEndpointSocket.Bind(new IPEndPoint(options.DefaultEndpointOptions.BoundIPAddress, options.GetDefaultEndpointPort())); _defaultEndpointSocket.Listen(options.ConnectionBacklog); @@ -102,7 +104,7 @@ namespace MQTTnet.Implementations #else var clientSocket = await _defaultEndpointSocket.AcceptAsync().ConfigureAwait(false); #endif - + clientSocket.NoDelay=true; var clientAdapter = new MqttChannelAdapter(new MqttTcpChannel(clientSocket, null), new MqttPacketSerializer(), _logger); ClientAccepted?.Invoke(this, new MqttServerAdapterClientAcceptedEventArgs(clientAdapter)); } diff --git a/Frameworks/MQTTnet.NetStandard/Serializer/IMqttPacketSerializer.cs b/Frameworks/MQTTnet.NetStandard/Serializer/IMqttPacketSerializer.cs index 6577b0a..6afdd98 100644 --- a/Frameworks/MQTTnet.NetStandard/Serializer/IMqttPacketSerializer.cs +++ b/Frameworks/MQTTnet.NetStandard/Serializer/IMqttPacketSerializer.cs @@ -9,7 +9,7 @@ namespace MQTTnet.Serializer { MqttProtocolVersion ProtocolVersion { get; set; } - ICollection> Serialize(MqttBasePacket mqttPacket); + ArraySegment Serialize(MqttBasePacket mqttPacket); MqttBasePacket Deserialize(MqttPacketHeader header, MemoryStream body); } diff --git a/Frameworks/MQTTnet.NetStandard/Serializer/MqttPacketSerializer.cs b/Frameworks/MQTTnet.NetStandard/Serializer/MqttPacketSerializer.cs index 0904c75..75e7239 100644 --- a/Frameworks/MQTTnet.NetStandard/Serializer/MqttPacketSerializer.cs +++ b/Frameworks/MQTTnet.NetStandard/Serializer/MqttPacketSerializer.cs @@ -16,29 +16,36 @@ namespace MQTTnet.Serializer public MqttProtocolVersion ProtocolVersion { get; set; } = MqttProtocolVersion.V311; - public ICollection> Serialize(MqttBasePacket packet) + public ArraySegment Serialize(MqttBasePacket packet) { if (packet == null) throw new ArgumentNullException(nameof(packet)); using (var stream = new MemoryStream(128)) using (var writer = new MqttPacketWriter(stream)) { + //leave enough head space for max header (fixed + 4 variable remaining lenght) + stream.Position = 5; + var fixedHeader = SerializePacket(packet, writer); - var remainingLength = (int)stream.Length; - writer.Write(fixedHeader); - MqttPacketWriter.WriteRemainingLength(remainingLength, writer); - var headerLength = (int)stream.Length - remainingLength; + var remainingLength = MqttPacketWriter.GetRemainingLength((int)stream.Length-5); + + var headerSize = remainingLength.Length + 1; + var headerOffset = 5 - headerSize; + + //position curson on correct offset on beginining of array + stream.Position = headerOffset; + + //write header + writer.Write(fixedHeader); + writer.Write(remainingLength,0,remainingLength.Length); + #if NET461 || NET452 || NETSTANDARD2_0 var buffer = stream.GetBuffer(); #else var buffer = stream.ToArray(); #endif - return new List> - { - new ArraySegment(buffer, remainingLength, headerLength), - new ArraySegment(buffer, 0, remainingLength) - }; + return new ArraySegment(buffer, headerOffset, (int)stream.Length- headerOffset); } } diff --git a/Frameworks/MQTTnet.NetStandard/Serializer/MqttPacketWriter.cs b/Frameworks/MQTTnet.NetStandard/Serializer/MqttPacketWriter.cs index cb3d458..e7cfa4c 100644 --- a/Frameworks/MQTTnet.NetStandard/Serializer/MqttPacketWriter.cs +++ b/Frameworks/MQTTnet.NetStandard/Serializer/MqttPacketWriter.cs @@ -1,5 +1,6 @@ using System; using System.IO; +using System.Linq; using System.Text; using MQTTnet.Protocol; @@ -55,14 +56,16 @@ namespace MQTTnet.Serializer Write(value); } - public static void WriteRemainingLength(int length, BinaryWriter target) + public static byte[] GetRemainingLength(int length) { - if (length == 0) + if (length <= 0) { - target.Write((byte)0); - return; + return new [] { (byte)0 }; } + var bytes = new byte[4]; + int arraySize = 0; + // Alorithm taken from http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html. var x = length; do @@ -74,8 +77,12 @@ namespace MQTTnet.Serializer encodedByte = encodedByte | 128; } - target.Write((byte)encodedByte); + bytes[arraySize] = (byte)encodedByte; + + arraySize++; } while (x > 0); + + return bytes.Take(arraySize).ToArray(); } } } diff --git a/Tests/MQTTnet.Core.Tests/MqttPacketSerializerTests.cs b/Tests/MQTTnet.Core.Tests/MqttPacketSerializerTests.cs index 1111d90..a7421c6 100644 --- a/Tests/MQTTnet.Core.Tests/MqttPacketSerializerTests.cs +++ b/Tests/MQTTnet.Core.Tests/MqttPacketSerializerTests.cs @@ -403,9 +403,9 @@ namespace MQTTnet.Core.Tests private static void SerializeAndCompare(MqttBasePacket packet, string expectedBase64Value, MqttProtocolVersion protocolVersion = MqttProtocolVersion.V311) { var serializer = new MqttPacketSerializer { ProtocolVersion = protocolVersion }; - var chunks = serializer.Serialize(packet); + var data = serializer.Serialize(packet); - Assert.AreEqual(expectedBase64Value, Convert.ToBase64String(Join(chunks))); + Assert.AreEqual(expectedBase64Value, Convert.ToBase64String(Join(data))); } private static void DeserializeAndCompare(MqttBasePacket packet, string expectedBase64Value, MqttProtocolVersion protocolVersion = MqttProtocolVersion.V311) @@ -438,5 +438,16 @@ namespace MQTTnet.Core.Tests return buffer.ToArray(); } + + private static byte[] Join(params ArraySegment[] chunks) + { + var buffer = new MemoryStream(); + foreach (var chunk in chunks) + { + buffer.Write(chunk.Array, chunk.Offset, chunk.Count); + } + + return buffer.ToArray(); + } } } diff --git a/Tests/MQTTnet.Core.Tests/TestMqttCommunicationAdapter.cs b/Tests/MQTTnet.Core.Tests/TestMqttCommunicationAdapter.cs index a898f76..f564e51 100644 --- a/Tests/MQTTnet.Core.Tests/TestMqttCommunicationAdapter.cs +++ b/Tests/MQTTnet.Core.Tests/TestMqttCommunicationAdapter.cs @@ -31,7 +31,7 @@ namespace MQTTnet.Core.Tests return Task.FromResult(0); } - public Task SendPacketsAsync(TimeSpan timeout, CancellationToken cancellationToken, IEnumerable packets) + public Task SendPacketsAsync(TimeSpan timeout, CancellationToken cancellationToken, MqttBasePacket[] packets) { ThrowIfPartnerIsNull();