@@ -23,6 +23,10 @@ namespace MQTTnet.Adapter | |||||
private readonly IMqttNetChildLogger _logger; | private readonly IMqttNetChildLogger _logger; | ||||
private readonly IMqttChannel _channel; | private readonly IMqttChannel _channel; | ||||
private readonly byte[] _fixedHeaderBuffer = new byte[2]; | |||||
private readonly byte[] _singleByteBuffer = new byte[1]; | |||||
private bool _isDisposed; | private bool _isDisposed; | ||||
public MqttChannelAdapter(IMqttChannel channel, IMqttPacketSerializer serializer, IMqttNetChildLogger logger) | public MqttChannelAdapter(IMqttChannel channel, IMqttPacketSerializer serializer, IMqttNetChildLogger logger) | ||||
@@ -163,7 +167,7 @@ namespace MQTTnet.Adapter | |||||
private async Task<ReceivedMqttPacket> ReceiveAsync(IMqttChannel channel, CancellationToken cancellationToken) | private async Task<ReceivedMqttPacket> ReceiveAsync(IMqttChannel channel, CancellationToken cancellationToken) | ||||
{ | { | ||||
var fixedHeader = await MqttPacketReader.ReadFixedHeaderAsync(channel, cancellationToken).ConfigureAwait(false); | |||||
var fixedHeader = await MqttPacketReader.ReadFixedHeaderAsync(channel, _fixedHeaderBuffer, _singleByteBuffer, cancellationToken).ConfigureAwait(false); | |||||
try | try | ||||
{ | { | ||||
@@ -4,7 +4,7 @@ | |||||
{ | { | ||||
public override string ToString() | public override string ToString() | ||||
{ | { | ||||
return "PubAck"; | |||||
return $"PubAck [PacketIdentifier={PacketIdentifier}]"; | |||||
} | } | ||||
} | } | ||||
} | } |
@@ -20,6 +20,7 @@ namespace MQTTnet.Serializer | |||||
public byte ReadByte() | public byte ReadByte() | ||||
{ | { | ||||
ValidateReceiveBuffer(1); | |||||
return _buffer[_offset++]; | return _buffer[_offset++]; | ||||
} | } | ||||
@@ -30,6 +31,8 @@ namespace MQTTnet.Serializer | |||||
public ushort ReadUInt16() | public ushort ReadUInt16() | ||||
{ | { | ||||
ValidateReceiveBuffer(2); | |||||
var msb = _buffer[_offset++]; | var msb = _buffer[_offset++]; | ||||
var lsb = _buffer[_offset++]; | var lsb = _buffer[_offset++]; | ||||
@@ -40,12 +43,22 @@ namespace MQTTnet.Serializer | |||||
{ | { | ||||
var length = ReadUInt16(); | var length = ReadUInt16(); | ||||
ValidateReceiveBuffer(length); | |||||
var result = new ArraySegment<byte>(_buffer, _offset, length); | var result = new ArraySegment<byte>(_buffer, _offset, length); | ||||
_offset += length; | _offset += length; | ||||
return result; | return result; | ||||
} | } | ||||
private void ValidateReceiveBuffer(ushort length) | |||||
{ | |||||
if (_buffer.Length < _offset + length) | |||||
{ | |||||
throw new ArgumentOutOfRangeException(nameof(_buffer), $"expected at least {_offset + length} bytes but there are only {_buffer.Length} bytes"); | |||||
} | |||||
} | |||||
public string ReadStringWithLengthPrefix() | public string ReadStringWithLengthPrefix() | ||||
{ | { | ||||
var buffer = ReadWithLengthPrefix(); | var buffer = ReadWithLengthPrefix(); | ||||
@@ -9,19 +9,13 @@ namespace MQTTnet.Serializer | |||||
{ | { | ||||
public static class MqttPacketReader | public static class MqttPacketReader | ||||
{ | { | ||||
[ThreadStatic] | |||||
private static byte[] _fixedHeaderBuffer; | |||||
[ThreadStatic] | |||||
private static byte[] _singleByteBuffer; | |||||
public static async Task<MqttFixedHeader> ReadFixedHeaderAsync(IMqttChannel channel, CancellationToken cancellationToken) | |||||
public static async Task<MqttFixedHeader> ReadFixedHeaderAsync(IMqttChannel channel, byte[] fixedHeaderBuffer, byte[] singleByteBuffer, CancellationToken cancellationToken) | |||||
{ | { | ||||
// The MQTT fixed header contains 1 byte of flags and at least 1 byte for the remaining data length. | // The MQTT fixed header contains 1 byte of flags and at least 1 byte for the remaining data length. | ||||
// So in all cases at least 2 bytes must be read for a complete MQTT packet. | // So in all cases at least 2 bytes must be read for a complete MQTT packet. | ||||
// async/await is used here because the next packet is received in a couple of minutes so the performance | // async/await is used here because the next packet is received in a couple of minutes so the performance | ||||
// impact is acceptable according to a useless waiting thread. | // impact is acceptable according to a useless waiting thread. | ||||
var buffer = InitializeFixedHeaderBuffer(); | |||||
var buffer = fixedHeaderBuffer; | |||||
var totalBytesRead = 0; | var totalBytesRead = 0; | ||||
while (totalBytesRead < buffer.Length) | while (totalBytesRead < buffer.Length) | ||||
@@ -41,12 +35,12 @@ namespace MQTTnet.Serializer | |||||
{ | { | ||||
return new MqttFixedHeader(buffer[0], 0); | return new MqttFixedHeader(buffer[0], 0); | ||||
} | } | ||||
var bodyLength = ReadBodyLength(channel, buffer[1], cancellationToken); | |||||
var bodyLength = ReadBodyLength(channel, buffer[1], singleByteBuffer, cancellationToken); | |||||
return new MqttFixedHeader(buffer[0], bodyLength); | return new MqttFixedHeader(buffer[0], bodyLength); | ||||
} | } | ||||
private static int ReadBodyLength(IMqttChannel channel, byte initialEncodedByte, CancellationToken cancellationToken) | |||||
private static int ReadBodyLength(IMqttChannel channel, byte initialEncodedByte, byte[] singleByteBuffer, CancellationToken cancellationToken) | |||||
{ | { | ||||
// Alorithm taken from https://docs.oasis-open.org/mqtt/mqtt/v3.1.1/errata01/os/mqtt-v3.1.1-errata01-os-complete.html. | // Alorithm taken from https://docs.oasis-open.org/mqtt/mqtt/v3.1.1/errata01/os/mqtt-v3.1.1-errata01-os-complete.html. | ||||
var multiplier = 128; | var multiplier = 128; | ||||
@@ -61,7 +55,7 @@ namespace MQTTnet.Serializer | |||||
// is too big for reading 1 byte in a row. We expect that the remaining data was sent | // is too big for reading 1 byte in a row. We expect that the remaining data was sent | ||||
// directly after the initial bytes. If the client disconnects just in this moment we | // directly after the initial bytes. If the client disconnects just in this moment we | ||||
// will get an exception anyway. | // will get an exception anyway. | ||||
encodedByte = ReadByte(channel, cancellationToken); | |||||
encodedByte = ReadByte(channel, singleByteBuffer, cancellationToken); | |||||
value += (byte)(encodedByte & 127) * multiplier; | value += (byte)(encodedByte & 127) * multiplier; | ||||
if (multiplier > 128 * 128 * 128) | if (multiplier > 128 * 128 * 128) | ||||
@@ -75,27 +69,16 @@ namespace MQTTnet.Serializer | |||||
return value; | return value; | ||||
} | } | ||||
private static byte ReadByte(IMqttChannel channel, CancellationToken cancellationToken) | |||||
private static byte ReadByte(IMqttChannel channel, byte[] singleByteBuffer, CancellationToken cancellationToken) | |||||
{ | { | ||||
var buffer = InitializeSingleByteBuffer(); | |||||
var readCount = channel.ReadAsync(buffer, 0, 1, cancellationToken).GetAwaiter().GetResult(); | |||||
var readCount = channel.ReadAsync(singleByteBuffer, 0, 1, cancellationToken).GetAwaiter().GetResult(); | |||||
if (readCount <= 0) | if (readCount <= 0) | ||||
{ | { | ||||
cancellationToken.ThrowIfCancellationRequested(); | cancellationToken.ThrowIfCancellationRequested(); | ||||
ExceptionHelper.ThrowGracefulSocketClose(); | ExceptionHelper.ThrowGracefulSocketClose(); | ||||
} | } | ||||
return buffer[0]; | |||||
} | |||||
private static byte[] InitializeFixedHeaderBuffer() | |||||
{ | |||||
return _fixedHeaderBuffer ?? (_fixedHeaderBuffer = new byte[2]); | |||||
} | |||||
private static byte[] InitializeSingleByteBuffer() | |||||
{ | |||||
return _singleByteBuffer ?? (_singleByteBuffer = new byte[1]); | |||||
return singleByteBuffer[0]; | |||||
} | } | ||||
} | } | ||||
} | } |
@@ -190,7 +190,7 @@ namespace MQTTnet.Serializer | |||||
var retain = (receivedMqttPacket.FixedHeader & 0x1) > 0; | var retain = (receivedMqttPacket.FixedHeader & 0x1) > 0; | ||||
var qualityOfServiceLevel = (MqttQualityOfServiceLevel)(receivedMqttPacket.FixedHeader >> 1 & 0x3); | var qualityOfServiceLevel = (MqttQualityOfServiceLevel)(receivedMqttPacket.FixedHeader >> 1 & 0x3); | ||||
var dup = (receivedMqttPacket.FixedHeader & 0x3) > 0; | |||||
var dup = (receivedMqttPacket.FixedHeader & 0x8) > 0; | |||||
var topic = receivedMqttPacket.Body.ReadStringWithLengthPrefix(); | var topic = receivedMqttPacket.Body.ReadStringWithLengthPrefix(); | ||||
@@ -99,7 +99,7 @@ namespace MQTTnet.Serializer | |||||
public void Reset() | public void Reset() | ||||
{ | { | ||||
Length = 0; | |||||
Length = 5; | |||||
} | } | ||||
public void Seek(int offset) | public void Seek(int offset) | ||||
@@ -47,12 +47,14 @@ namespace MQTTnet.Benchmarks | |||||
public void Deserialize_10000_Messages() | public void Deserialize_10000_Messages() | ||||
{ | { | ||||
var channel = new BenchmarkMqttChannel(_serializedPacket); | var channel = new BenchmarkMqttChannel(_serializedPacket); | ||||
var fixedHeader = new byte[2]; | |||||
var singleByteBuffer = new byte[1]; | |||||
for (var i = 0; i < 10000; i++) | for (var i = 0; i < 10000; i++) | ||||
{ | { | ||||
channel.Reset(); | channel.Reset(); | ||||
var header = MqttPacketReader.ReadFixedHeaderAsync(channel, CancellationToken.None).GetAwaiter().GetResult(); | |||||
var header = MqttPacketReader.ReadFixedHeaderAsync(channel, fixedHeader, singleByteBuffer, CancellationToken.None).GetAwaiter().GetResult(); | |||||
var receivedPacket = new ReceivedMqttPacket( | var receivedPacket = new ReceivedMqttPacket( | ||||
header.Flags, | header.Flags, | ||||
@@ -14,7 +14,9 @@ namespace MQTTnet.Core.Tests | |||||
[ExpectedException(typeof(MqttCommunicationClosedGracefullyException))] | [ExpectedException(typeof(MqttCommunicationClosedGracefullyException))] | ||||
public void MqttPacketReader_EmptyStream() | public void MqttPacketReader_EmptyStream() | ||||
{ | { | ||||
MqttPacketReader.ReadFixedHeaderAsync(new TestMqttChannel(new MemoryStream()), CancellationToken.None).GetAwaiter().GetResult(); | |||||
var fixedHeader = new byte[2]; | |||||
var singleByteBuffer = new byte[1]; | |||||
MqttPacketReader.ReadFixedHeaderAsync(new TestMqttChannel(new MemoryStream()), fixedHeader, singleByteBuffer, CancellationToken.None).GetAwaiter().GetResult(); | |||||
} | } | ||||
} | } | ||||
} | } |
@@ -199,6 +199,64 @@ namespace MQTTnet.Core.Tests | |||||
DeserializeAndCompare(p, "Ow4ABUEvQi9DAHtIRUxMTw=="); | DeserializeAndCompare(p, "Ow4ABUEvQi9DAHtIRUxMTw=="); | ||||
} | } | ||||
[TestMethod] | |||||
public void DeserializeV311_MqttPublishPacket_Qos1() | |||||
{ | |||||
var p = new MqttPublishPacket | |||||
{ | |||||
QualityOfServiceLevel = MqttQualityOfServiceLevel.AtMostOnce, | |||||
}; | |||||
var p2 = Roundtrip(p); | |||||
Assert.AreEqual(p.QualityOfServiceLevel, p2.QualityOfServiceLevel); | |||||
Assert.AreEqual(p.Dup, p2.Dup); | |||||
} | |||||
[TestMethod] | |||||
public void DeserializeV311_MqttPublishPacket_Qos2() | |||||
{ | |||||
var p = new MqttPublishPacket | |||||
{ | |||||
QualityOfServiceLevel = MqttQualityOfServiceLevel.AtLeastOnce, | |||||
PacketIdentifier = 1 | |||||
}; | |||||
var p2 = Roundtrip(p); | |||||
Assert.AreEqual(p.QualityOfServiceLevel, p2.QualityOfServiceLevel); | |||||
Assert.AreEqual(p.Dup, p2.Dup); | |||||
} | |||||
[TestMethod] | |||||
public void DeserializeV311_MqttPublishPacket_Qos3() | |||||
{ | |||||
var p = new MqttPublishPacket | |||||
{ | |||||
QualityOfServiceLevel = MqttQualityOfServiceLevel.ExactlyOnce, | |||||
PacketIdentifier = 1 | |||||
}; | |||||
var p2 = Roundtrip(p); | |||||
Assert.AreEqual(p.QualityOfServiceLevel, p2.QualityOfServiceLevel); | |||||
Assert.AreEqual(p.Dup, p2.Dup); | |||||
} | |||||
[TestMethod] | |||||
public void DeserializeV311_MqttPublishPacket_DupFalse() | |||||
{ | |||||
var p = new MqttPublishPacket | |||||
{ | |||||
Dup = false, | |||||
}; | |||||
var p2 = Roundtrip(p); | |||||
Assert.AreEqual(p.Dup, p2.Dup); | |||||
} | |||||
[TestMethod] | [TestMethod] | ||||
public void SerializeV311_MqttPubAckPacket() | public void SerializeV311_MqttPubAckPacket() | ||||
{ | { | ||||
@@ -418,7 +476,9 @@ namespace MQTTnet.Core.Tests | |||||
using (var headerStream = new MemoryStream(Join(buffer1))) | using (var headerStream = new MemoryStream(Join(buffer1))) | ||||
{ | { | ||||
var channel = new TestMqttChannel(headerStream); | var channel = new TestMqttChannel(headerStream); | ||||
var header = MqttPacketReader.ReadFixedHeaderAsync(channel, CancellationToken.None).GetAwaiter().GetResult(); | |||||
var fixedHeader = new byte[2]; | |||||
var singleByteBuffer = new byte[1]; | |||||
var header = MqttPacketReader.ReadFixedHeaderAsync(channel, fixedHeader, singleByteBuffer, CancellationToken.None).GetAwaiter().GetResult(); | |||||
using (var bodyStream = new MemoryStream(Join(buffer1), (int)headerStream.Position, header.RemainingLength)) | using (var bodyStream = new MemoryStream(Join(buffer1), (int)headerStream.Position, header.RemainingLength)) | ||||
{ | { | ||||
@@ -430,6 +490,28 @@ namespace MQTTnet.Core.Tests | |||||
} | } | ||||
} | } | ||||
private static T Roundtrip<T>(T packet, MqttProtocolVersion protocolVersion = MqttProtocolVersion.V311) | |||||
where T : MqttBasePacket | |||||
{ | |||||
var serializer = new MqttPacketSerializer { ProtocolVersion = protocolVersion }; | |||||
var buffer1 = serializer.Serialize(packet); | |||||
using (var headerStream = new MemoryStream(Join(buffer1))) | |||||
{ | |||||
var channel = new TestMqttChannel(headerStream); | |||||
var fixedHeader = new byte[2]; | |||||
var singleByteBuffer = new byte[1]; | |||||
var header = MqttPacketReader.ReadFixedHeaderAsync(channel, fixedHeader, singleByteBuffer, CancellationToken.None).GetAwaiter().GetResult(); | |||||
using (var bodyStream = new MemoryStream(Join(buffer1), (int)headerStream.Position, header.RemainingLength)) | |||||
{ | |||||
return (T)serializer.Deserialize(new ReceivedMqttPacket(header.Flags, new MqttPacketBodyReader(bodyStream.ToArray(), 0))); | |||||
} | |||||
} | |||||
} | |||||
private static byte[] Join(params ArraySegment<byte>[] chunks) | private static byte[] Join(params ArraySegment<byte>[] chunks) | ||||
{ | { | ||||
var buffer = new MemoryStream(); | var buffer = new MemoryStream(); | ||||
@@ -8,6 +8,7 @@ using System.Collections.Generic; | |||||
using System.Text; | using System.Text; | ||||
using System.Threading; | using System.Threading; | ||||
using System.Threading.Tasks; | using System.Threading.Tasks; | ||||
using MQTTnet.Implementations; | |||||
namespace MQTTnet.Core.Tests | namespace MQTTnet.Core.Tests | ||||
{ | { | ||||
@@ -167,6 +168,72 @@ namespace MQTTnet.Core.Tests | |||||
Assert.AreEqual(1, receivedMessagesCount); | Assert.AreEqual(1, receivedMessagesCount); | ||||
} | } | ||||
[TestMethod] | |||||
public async Task MqttServer_Publish_MultipleClients() | |||||
{ | |||||
var serverAdapter = new MqttTcpServerAdapter(new MqttNetLogger().CreateChildLogger()); | |||||
var s = new MqttFactory().CreateMqttServer(new[] { serverAdapter }, new MqttNetLogger()); | |||||
var receivedMessagesCount = 0; | |||||
var locked = new object(); | |||||
var clientOptions = new MqttClientOptionsBuilder() | |||||
.WithTcpServer("localhost") | |||||
.Build(); | |||||
var clientOptions2 = new MqttClientOptionsBuilder() | |||||
.WithTcpServer("localhost") | |||||
.Build(); | |||||
try | |||||
{ | |||||
await s.StartAsync(new MqttServerOptions()); | |||||
var c1 = new MqttFactory().CreateMqttClient(); | |||||
var c2 = new MqttFactory().CreateMqttClient(); | |||||
await c1.ConnectAsync(clientOptions); | |||||
await c2.ConnectAsync(clientOptions2); | |||||
c1.ApplicationMessageReceived += (_, __) => | |||||
{ | |||||
lock (locked) | |||||
{ | |||||
receivedMessagesCount++; | |||||
} | |||||
}; | |||||
c2.ApplicationMessageReceived += (_, __) => | |||||
{ | |||||
lock (locked) | |||||
{ | |||||
receivedMessagesCount++; | |||||
} | |||||
}; | |||||
var message = new MqttApplicationMessageBuilder().WithTopic("a").WithAtLeastOnceQoS().Build(); | |||||
await c1.SubscribeAsync(new TopicFilter("a", MqttQualityOfServiceLevel.AtLeastOnce)); | |||||
await c2.SubscribeAsync(new TopicFilter("a", MqttQualityOfServiceLevel.AtLeastOnce)); | |||||
//await Task.WhenAll(Publish(c1, message), Publish(c2, message)); | |||||
await Publish(c1, message); | |||||
await Task.Delay(500); | |||||
} | |||||
finally | |||||
{ | |||||
await s.StopAsync(); | |||||
} | |||||
Assert.AreEqual(2000, receivedMessagesCount); | |||||
} | |||||
private static async Task Publish(IMqttClient c1, MqttApplicationMessage message) | |||||
{ | |||||
for (int i = 0; i < 1000; i++) | |||||
{ | |||||
await c1.PublishAsync(message); | |||||
} | |||||
} | |||||
[TestMethod] | [TestMethod] | ||||
public async Task MqttServer_RetainedMessagesFlow() | public async Task MqttServer_RetainedMessagesFlow() | ||||
{ | { | ||||