diff --git a/Frameworks/MQTTnet.NetStandard/Serializer/MqttPacketReader.cs b/Frameworks/MQTTnet.NetStandard/Serializer/MqttPacketReader.cs index ef07848..5359d99 100644 --- a/Frameworks/MQTTnet.NetStandard/Serializer/MqttPacketReader.cs +++ b/Frameworks/MQTTnet.NetStandard/Serializer/MqttPacketReader.cs @@ -24,14 +24,17 @@ namespace MQTTnet.Serializer public static async Task ReadHeaderFromSourceAsync(Stream stream, CancellationToken cancellationToken) { - byte[] singleByteBuf = new byte[1]; - var readCount = await stream.ReadAsync(singleByteBuf, 0, singleByteBuf.Length).ConfigureAwait(false); + // Wait for the next package which starts with the header. At this point there will probably + // some large delay and thus the thread should be put back to the pool (await). So ReadByte() + // is not an option here. + var buffer = new byte[1]; + var readCount = await stream.ReadAsync(buffer, 0, 1, cancellationToken).ConfigureAwait(false); if (readCount <= 0) { return null; } - var fixedHeader = singleByteBuf[0]; + var fixedHeader = buffer[0]; var controlPacketType = (MqttControlPacketType)(fixedHeader >> 4); var bodyLength = await ReadBodyLengthFromSourceAsync(stream, cancellationToken).ConfigureAwait(false); @@ -88,8 +91,7 @@ namespace MQTTnet.Serializer var value = 0; byte encodedByte; - byte[] singleByteBuf = new byte[1]; - + var buffer = new byte[1]; var readBytes = new List(); do { @@ -98,13 +100,13 @@ namespace MQTTnet.Serializer throw new TaskCanceledException(); } - int readCount = await stream.ReadAsync(singleByteBuf, 0, singleByteBuf.Length).ConfigureAwait(false); + var readCount = await stream.ReadAsync(buffer, 0, 1, cancellationToken).ConfigureAwait(false); if (readCount <= 0) { throw new MqttCommunicationException("Connection closed while reading remaining length data."); } - encodedByte = singleByteBuf[0]; + encodedByte = buffer[0]; readBytes.Add(encodedByte); value += (byte)(encodedByte & 127) * multiplier; diff --git a/Frameworks/MQTTnet.NetStandard/Server/MqttClientPendingMessagesQueue.cs b/Frameworks/MQTTnet.NetStandard/Server/MqttClientPendingMessagesQueue.cs index e092d3b..62a4d78 100644 --- a/Frameworks/MQTTnet.NetStandard/Server/MqttClientPendingMessagesQueue.cs +++ b/Frameworks/MQTTnet.NetStandard/Server/MqttClientPendingMessagesQueue.cs @@ -43,6 +43,7 @@ namespace MQTTnet.Server _queue.Enqueue(packet); _queueWaitSemaphore.Release(); + _logger.Trace("Enqueued packet (ClientId: {0}).", _session.ClientId); } @@ -70,9 +71,11 @@ namespace MQTTnet.Server try { await _queueWaitSemaphore.WaitAsync(cancellationToken).ConfigureAwait(false); - if (!_queue.TryDequeue(out packet)) { + if (!_queue.TryDequeue(out packet)) + { throw new InvalidOperationException(); // should not happen } + await adapter.SendPacketsAsync(_options.DefaultCommunicationTimeout, cancellationToken, packet).ConfigureAwait(false); _logger.Trace("Enqueued packet sent (ClientId: {0}).", _session.ClientId); diff --git a/Tests/MQTTnet.Core.Tests/MqttClientTests.cs b/Tests/MQTTnet.Core.Tests/MqttClientTests.cs new file mode 100644 index 0000000..d434325 --- /dev/null +++ b/Tests/MQTTnet.Core.Tests/MqttClientTests.cs @@ -0,0 +1,37 @@ +using System; +using System.Net.Sockets; +using System.Threading.Tasks; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using MQTTnet.Client; +using MQTTnet.Exceptions; + +namespace MQTTnet.Core.Tests +{ + [TestClass] + public class MqttClientTests + { + + [TestMethod] + public async Task ClientDisconnectException() + { + var factory = new MqttFactory(); + var client = factory.CreateMqttClient(); + + var exceptionIsCorrect = false; + client.Disconnected += (s, e) => + { + exceptionIsCorrect = e.Exception is MqttCommunicationException c && c.InnerException is SocketException; + }; + + try + { + await client.ConnectAsync(new MqttClientOptionsBuilder().WithTcpServer("wrong-server").Build()); + } + catch + { + } + + Assert.IsTrue(exceptionIsCorrect); + } + } +} diff --git a/Tests/MQTTnet.Core.Tests/MqttPacketReaderTests.cs b/Tests/MQTTnet.Core.Tests/MqttPacketReaderTests.cs index 606d1da..4d0e9b0 100644 --- a/Tests/MQTTnet.Core.Tests/MqttPacketReaderTests.cs +++ b/Tests/MQTTnet.Core.Tests/MqttPacketReaderTests.cs @@ -12,7 +12,7 @@ namespace MQTTnet.Core.Tests public void MqttPacketReader_EmptyStream() { var memStream = new MemoryStream(); - var header = MqttPacketReader.ReadHeaderFromSource(memStream, CancellationToken.None); + var header = MqttPacketReader.ReadHeaderFromSourceAsync(memStream, CancellationToken.None).GetAwaiter().GetResult(); Assert.IsNull(header); } diff --git a/Tests/MQTTnet.Core.Tests/MqttPacketSerializerTests.cs b/Tests/MQTTnet.Core.Tests/MqttPacketSerializerTests.cs index 9f51b86..5a75490 100644 --- a/Tests/MQTTnet.Core.Tests/MqttPacketSerializerTests.cs +++ b/Tests/MQTTnet.Core.Tests/MqttPacketSerializerTests.cs @@ -405,7 +405,7 @@ namespace MQTTnet.Core.Tests using (var headerStream = new MemoryStream(Join(buffer1))) { - var header = MqttPacketReader.ReadHeaderFromSource(headerStream, CancellationToken.None); + var header = MqttPacketReader.ReadHeaderFromSourceAsync(headerStream, CancellationToken.None).GetAwaiter().GetResult(); using (var bodyStream = new MemoryStream(Join(buffer1), (int)headerStream.Position, header.BodyLength)) {