@@ -23,6 +23,10 @@ namespace MQTTnet.Adapter | |||
private readonly IMqttNetChildLogger _logger; | |||
private readonly IMqttChannel _channel; | |||
private readonly byte[] _fixedHeaderBuffer = new byte[2]; | |||
private readonly byte[] _singleByteBuffer = new byte[1]; | |||
private bool _isDisposed; | |||
public MqttChannelAdapter(IMqttChannel channel, IMqttPacketSerializer serializer, IMqttNetChildLogger logger) | |||
@@ -163,7 +167,7 @@ namespace MQTTnet.Adapter | |||
private async Task<ReceivedMqttPacket> ReceiveAsync(IMqttChannel channel, CancellationToken cancellationToken) | |||
{ | |||
var fixedHeader = await MqttPacketReader.ReadFixedHeaderAsync(channel, cancellationToken).ConfigureAwait(false); | |||
var fixedHeader = await MqttPacketReader.ReadFixedHeaderAsync(channel, _fixedHeaderBuffer, _singleByteBuffer, cancellationToken).ConfigureAwait(false); | |||
try | |||
{ | |||
@@ -4,7 +4,7 @@ | |||
{ | |||
public override string ToString() | |||
{ | |||
return "PubAck"; | |||
return $"PubAck [PacketIdentifier={PacketIdentifier}]"; | |||
} | |||
} | |||
} |
@@ -20,6 +20,7 @@ namespace MQTTnet.Serializer | |||
public byte ReadByte() | |||
{ | |||
ValidateReceiveBuffer(1); | |||
return _buffer[_offset++]; | |||
} | |||
@@ -30,6 +31,8 @@ namespace MQTTnet.Serializer | |||
public ushort ReadUInt16() | |||
{ | |||
ValidateReceiveBuffer(2); | |||
var msb = _buffer[_offset++]; | |||
var lsb = _buffer[_offset++]; | |||
@@ -40,12 +43,22 @@ namespace MQTTnet.Serializer | |||
{ | |||
var length = ReadUInt16(); | |||
ValidateReceiveBuffer(length); | |||
var result = new ArraySegment<byte>(_buffer, _offset, length); | |||
_offset += length; | |||
return result; | |||
} | |||
private void ValidateReceiveBuffer(ushort length) | |||
{ | |||
if (_buffer.Length < _offset + length) | |||
{ | |||
throw new ArgumentOutOfRangeException(nameof(_buffer), $"expected at least {_offset + length} bytes but there are only {_buffer.Length} bytes"); | |||
} | |||
} | |||
public string ReadStringWithLengthPrefix() | |||
{ | |||
var buffer = ReadWithLengthPrefix(); | |||
@@ -9,19 +9,13 @@ namespace MQTTnet.Serializer | |||
{ | |||
public static class MqttPacketReader | |||
{ | |||
[ThreadStatic] | |||
private static byte[] _fixedHeaderBuffer; | |||
[ThreadStatic] | |||
private static byte[] _singleByteBuffer; | |||
public static async Task<MqttFixedHeader> ReadFixedHeaderAsync(IMqttChannel channel, CancellationToken cancellationToken) | |||
public static async Task<MqttFixedHeader> ReadFixedHeaderAsync(IMqttChannel channel, byte[] fixedHeaderBuffer, byte[] singleByteBuffer, CancellationToken cancellationToken) | |||
{ | |||
// The MQTT fixed header contains 1 byte of flags and at least 1 byte for the remaining data length. | |||
// So in all cases at least 2 bytes must be read for a complete MQTT packet. | |||
// async/await is used here because the next packet is received in a couple of minutes so the performance | |||
// impact is acceptable according to a useless waiting thread. | |||
var buffer = InitializeFixedHeaderBuffer(); | |||
var buffer = fixedHeaderBuffer; | |||
var totalBytesRead = 0; | |||
while (totalBytesRead < buffer.Length) | |||
@@ -41,12 +35,12 @@ namespace MQTTnet.Serializer | |||
{ | |||
return new MqttFixedHeader(buffer[0], 0); | |||
} | |||
var bodyLength = ReadBodyLength(channel, buffer[1], cancellationToken); | |||
var bodyLength = ReadBodyLength(channel, buffer[1], singleByteBuffer, cancellationToken); | |||
return new MqttFixedHeader(buffer[0], bodyLength); | |||
} | |||
private static int ReadBodyLength(IMqttChannel channel, byte initialEncodedByte, CancellationToken cancellationToken) | |||
private static int ReadBodyLength(IMqttChannel channel, byte initialEncodedByte, byte[] singleByteBuffer, CancellationToken cancellationToken) | |||
{ | |||
// Alorithm taken from https://docs.oasis-open.org/mqtt/mqtt/v3.1.1/errata01/os/mqtt-v3.1.1-errata01-os-complete.html. | |||
var multiplier = 128; | |||
@@ -61,7 +55,7 @@ namespace MQTTnet.Serializer | |||
// is too big for reading 1 byte in a row. We expect that the remaining data was sent | |||
// directly after the initial bytes. If the client disconnects just in this moment we | |||
// will get an exception anyway. | |||
encodedByte = ReadByte(channel, cancellationToken); | |||
encodedByte = ReadByte(channel, singleByteBuffer, cancellationToken); | |||
value += (byte)(encodedByte & 127) * multiplier; | |||
if (multiplier > 128 * 128 * 128) | |||
@@ -75,27 +69,16 @@ namespace MQTTnet.Serializer | |||
return value; | |||
} | |||
private static byte ReadByte(IMqttChannel channel, CancellationToken cancellationToken) | |||
private static byte ReadByte(IMqttChannel channel, byte[] singleByteBuffer, CancellationToken cancellationToken) | |||
{ | |||
var buffer = InitializeSingleByteBuffer(); | |||
var readCount = channel.ReadAsync(buffer, 0, 1, cancellationToken).GetAwaiter().GetResult(); | |||
var readCount = channel.ReadAsync(singleByteBuffer, 0, 1, cancellationToken).GetAwaiter().GetResult(); | |||
if (readCount <= 0) | |||
{ | |||
cancellationToken.ThrowIfCancellationRequested(); | |||
ExceptionHelper.ThrowGracefulSocketClose(); | |||
} | |||
return buffer[0]; | |||
} | |||
private static byte[] InitializeFixedHeaderBuffer() | |||
{ | |||
return _fixedHeaderBuffer ?? (_fixedHeaderBuffer = new byte[2]); | |||
} | |||
private static byte[] InitializeSingleByteBuffer() | |||
{ | |||
return _singleByteBuffer ?? (_singleByteBuffer = new byte[1]); | |||
return singleByteBuffer[0]; | |||
} | |||
} | |||
} |
@@ -190,7 +190,7 @@ namespace MQTTnet.Serializer | |||
var retain = (receivedMqttPacket.FixedHeader & 0x1) > 0; | |||
var qualityOfServiceLevel = (MqttQualityOfServiceLevel)(receivedMqttPacket.FixedHeader >> 1 & 0x3); | |||
var dup = (receivedMqttPacket.FixedHeader & 0x3) > 0; | |||
var dup = (receivedMqttPacket.FixedHeader & 0x8) > 0; | |||
var topic = receivedMqttPacket.Body.ReadStringWithLengthPrefix(); | |||
@@ -99,7 +99,7 @@ namespace MQTTnet.Serializer | |||
public void Reset() | |||
{ | |||
Length = 0; | |||
Length = 5; | |||
} | |||
public void Seek(int offset) | |||
@@ -47,12 +47,14 @@ namespace MQTTnet.Benchmarks | |||
public void Deserialize_10000_Messages() | |||
{ | |||
var channel = new BenchmarkMqttChannel(_serializedPacket); | |||
var fixedHeader = new byte[2]; | |||
var singleByteBuffer = new byte[1]; | |||
for (var i = 0; i < 10000; i++) | |||
{ | |||
channel.Reset(); | |||
var header = MqttPacketReader.ReadFixedHeaderAsync(channel, CancellationToken.None).GetAwaiter().GetResult(); | |||
var header = MqttPacketReader.ReadFixedHeaderAsync(channel, fixedHeader, singleByteBuffer, CancellationToken.None).GetAwaiter().GetResult(); | |||
var receivedPacket = new ReceivedMqttPacket( | |||
header.Flags, | |||
@@ -14,7 +14,9 @@ namespace MQTTnet.Core.Tests | |||
[ExpectedException(typeof(MqttCommunicationClosedGracefullyException))] | |||
public void MqttPacketReader_EmptyStream() | |||
{ | |||
MqttPacketReader.ReadFixedHeaderAsync(new TestMqttChannel(new MemoryStream()), CancellationToken.None).GetAwaiter().GetResult(); | |||
var fixedHeader = new byte[2]; | |||
var singleByteBuffer = new byte[1]; | |||
MqttPacketReader.ReadFixedHeaderAsync(new TestMqttChannel(new MemoryStream()), fixedHeader, singleByteBuffer, CancellationToken.None).GetAwaiter().GetResult(); | |||
} | |||
} | |||
} |
@@ -199,6 +199,64 @@ namespace MQTTnet.Core.Tests | |||
DeserializeAndCompare(p, "Ow4ABUEvQi9DAHtIRUxMTw=="); | |||
} | |||
[TestMethod] | |||
public void DeserializeV311_MqttPublishPacket_Qos1() | |||
{ | |||
var p = new MqttPublishPacket | |||
{ | |||
QualityOfServiceLevel = MqttQualityOfServiceLevel.AtMostOnce, | |||
}; | |||
var p2 = Roundtrip(p); | |||
Assert.AreEqual(p.QualityOfServiceLevel, p2.QualityOfServiceLevel); | |||
Assert.AreEqual(p.Dup, p2.Dup); | |||
} | |||
[TestMethod] | |||
public void DeserializeV311_MqttPublishPacket_Qos2() | |||
{ | |||
var p = new MqttPublishPacket | |||
{ | |||
QualityOfServiceLevel = MqttQualityOfServiceLevel.AtLeastOnce, | |||
PacketIdentifier = 1 | |||
}; | |||
var p2 = Roundtrip(p); | |||
Assert.AreEqual(p.QualityOfServiceLevel, p2.QualityOfServiceLevel); | |||
Assert.AreEqual(p.Dup, p2.Dup); | |||
} | |||
[TestMethod] | |||
public void DeserializeV311_MqttPublishPacket_Qos3() | |||
{ | |||
var p = new MqttPublishPacket | |||
{ | |||
QualityOfServiceLevel = MqttQualityOfServiceLevel.ExactlyOnce, | |||
PacketIdentifier = 1 | |||
}; | |||
var p2 = Roundtrip(p); | |||
Assert.AreEqual(p.QualityOfServiceLevel, p2.QualityOfServiceLevel); | |||
Assert.AreEqual(p.Dup, p2.Dup); | |||
} | |||
[TestMethod] | |||
public void DeserializeV311_MqttPublishPacket_DupFalse() | |||
{ | |||
var p = new MqttPublishPacket | |||
{ | |||
Dup = false, | |||
}; | |||
var p2 = Roundtrip(p); | |||
Assert.AreEqual(p.Dup, p2.Dup); | |||
} | |||
[TestMethod] | |||
public void SerializeV311_MqttPubAckPacket() | |||
{ | |||
@@ -418,7 +476,9 @@ namespace MQTTnet.Core.Tests | |||
using (var headerStream = new MemoryStream(Join(buffer1))) | |||
{ | |||
var channel = new TestMqttChannel(headerStream); | |||
var header = MqttPacketReader.ReadFixedHeaderAsync(channel, CancellationToken.None).GetAwaiter().GetResult(); | |||
var fixedHeader = new byte[2]; | |||
var singleByteBuffer = new byte[1]; | |||
var header = MqttPacketReader.ReadFixedHeaderAsync(channel, fixedHeader, singleByteBuffer, CancellationToken.None).GetAwaiter().GetResult(); | |||
using (var bodyStream = new MemoryStream(Join(buffer1), (int)headerStream.Position, header.RemainingLength)) | |||
{ | |||
@@ -430,6 +490,28 @@ namespace MQTTnet.Core.Tests | |||
} | |||
} | |||
private static T Roundtrip<T>(T packet, MqttProtocolVersion protocolVersion = MqttProtocolVersion.V311) | |||
where T : MqttBasePacket | |||
{ | |||
var serializer = new MqttPacketSerializer { ProtocolVersion = protocolVersion }; | |||
var buffer1 = serializer.Serialize(packet); | |||
using (var headerStream = new MemoryStream(Join(buffer1))) | |||
{ | |||
var channel = new TestMqttChannel(headerStream); | |||
var fixedHeader = new byte[2]; | |||
var singleByteBuffer = new byte[1]; | |||
var header = MqttPacketReader.ReadFixedHeaderAsync(channel, fixedHeader, singleByteBuffer, CancellationToken.None).GetAwaiter().GetResult(); | |||
using (var bodyStream = new MemoryStream(Join(buffer1), (int)headerStream.Position, header.RemainingLength)) | |||
{ | |||
return (T)serializer.Deserialize(new ReceivedMqttPacket(header.Flags, new MqttPacketBodyReader(bodyStream.ToArray(), 0))); | |||
} | |||
} | |||
} | |||
private static byte[] Join(params ArraySegment<byte>[] chunks) | |||
{ | |||
var buffer = new MemoryStream(); | |||
@@ -8,6 +8,7 @@ using System.Collections.Generic; | |||
using System.Text; | |||
using System.Threading; | |||
using System.Threading.Tasks; | |||
using MQTTnet.Implementations; | |||
namespace MQTTnet.Core.Tests | |||
{ | |||
@@ -167,6 +168,72 @@ namespace MQTTnet.Core.Tests | |||
Assert.AreEqual(1, receivedMessagesCount); | |||
} | |||
[TestMethod] | |||
public async Task MqttServer_Publish_MultipleClients() | |||
{ | |||
var serverAdapter = new MqttTcpServerAdapter(new MqttNetLogger().CreateChildLogger()); | |||
var s = new MqttFactory().CreateMqttServer(new[] { serverAdapter }, new MqttNetLogger()); | |||
var receivedMessagesCount = 0; | |||
var locked = new object(); | |||
var clientOptions = new MqttClientOptionsBuilder() | |||
.WithTcpServer("localhost") | |||
.Build(); | |||
var clientOptions2 = new MqttClientOptionsBuilder() | |||
.WithTcpServer("localhost") | |||
.Build(); | |||
try | |||
{ | |||
await s.StartAsync(new MqttServerOptions()); | |||
var c1 = new MqttFactory().CreateMqttClient(); | |||
var c2 = new MqttFactory().CreateMqttClient(); | |||
await c1.ConnectAsync(clientOptions); | |||
await c2.ConnectAsync(clientOptions2); | |||
c1.ApplicationMessageReceived += (_, __) => | |||
{ | |||
lock (locked) | |||
{ | |||
receivedMessagesCount++; | |||
} | |||
}; | |||
c2.ApplicationMessageReceived += (_, __) => | |||
{ | |||
lock (locked) | |||
{ | |||
receivedMessagesCount++; | |||
} | |||
}; | |||
var message = new MqttApplicationMessageBuilder().WithTopic("a").WithAtLeastOnceQoS().Build(); | |||
await c1.SubscribeAsync(new TopicFilter("a", MqttQualityOfServiceLevel.AtLeastOnce)); | |||
await c2.SubscribeAsync(new TopicFilter("a", MqttQualityOfServiceLevel.AtLeastOnce)); | |||
//await Task.WhenAll(Publish(c1, message), Publish(c2, message)); | |||
await Publish(c1, message); | |||
await Task.Delay(500); | |||
} | |||
finally | |||
{ | |||
await s.StopAsync(); | |||
} | |||
Assert.AreEqual(2000, receivedMessagesCount); | |||
} | |||
private static async Task Publish(IMqttClient c1, MqttApplicationMessage message) | |||
{ | |||
for (int i = 0; i < 1000; i++) | |||
{ | |||
await c1.PublishAsync(message); | |||
} | |||
} | |||
[TestMethod] | |||
public async Task MqttServer_RetainedMessagesFlow() | |||
{ | |||