fixed todo in MqttConnectionContextrelease/3.x.x
@@ -17,6 +17,7 @@ | |||
<dependency id="MQTTnet" version="$nugetVersion" /> | |||
<dependency id="Microsoft.AspNetCore.Connections.Abstractions" version="2.1.3" /> | |||
<dependency id="Microsoft.AspNetCore.Http.Connections" version="1.0.3" /> | |||
<dependency id="Microsoft.AspNetCore.WebSockets" version="2.1.1" /> | |||
<dependency id="Microsoft.Extensions.Hosting.Abstractions" version="2.1.1" /> | |||
</dependencies> | |||
@@ -16,6 +16,7 @@ | |||
<ItemGroup> | |||
<PackageReference Include="Microsoft.AspNetCore.Connections.Abstractions" Version="2.1.3" /> | |||
<PackageReference Include="Microsoft.AspNetCore.Http.Connections" Version="1.0.3" /> | |||
<PackageReference Include="Microsoft.AspNetCore.WebSockets" Version="2.1.1" /> | |||
<PackageReference Include="Microsoft.Extensions.Hosting.Abstractions" Version="2.1.1" /> | |||
</ItemGroup> | |||
@@ -1,13 +1,14 @@ | |||
using Microsoft.AspNetCore.Connections; | |||
using Microsoft.AspNetCore.Http.Connections.Features; | |||
using MQTTnet.Adapter; | |||
using MQTTnet.AspNetCore.Client.Tcp; | |||
using MQTTnet.Exceptions; | |||
using MQTTnet.Formatter; | |||
using MQTTnet.Packets; | |||
using System; | |||
using System.IO.Pipelines; | |||
using System.Threading; | |||
using System.Threading.Tasks; | |||
using MQTTnet.Formatter; | |||
namespace MQTTnet.AspNetCore | |||
{ | |||
@@ -32,14 +33,28 @@ namespace MQTTnet.AspNetCore | |||
private PipeWriter _output; | |||
private readonly SpanBasedMqttPacketBodyReader _reader; | |||
public string Endpoint => Connection.ConnectionId; | |||
public bool IsSecureConnection => false; // TODO: Fix detection (WS vs. WSS). | |||
public string Endpoint | |||
{ | |||
get { | |||
var connection = Http?.HttpContext?.Connection; | |||
if (connection == null) | |||
{ | |||
return Connection.ConnectionId; | |||
} | |||
return $"{connection.RemoteIpAddress}:{connection.RemotePort}"; | |||
} | |||
} | |||
public bool IsSecureConnection => Http?.HttpContext?.Request?.IsHttps ?? false; | |||
private IHttpContextFeature Http => Connection.Features.Get<IHttpContextFeature>(); | |||
public ConnectionContext Connection { get; } | |||
public MqttPacketFormatterAdapter PacketFormatterAdapter { get; } | |||
public long BytesSent { get; } // TODO: Fix calculation. | |||
public long BytesReceived { get; } // TODO: Fix calculation. | |||
public long BytesSent { get; set; } | |||
public long BytesReceived { get; set; } | |||
public Action ReadingPacketStartedCallback { get; set; } | |||
public Action ReadingPacketCompletedCallback { get; set; } | |||
@@ -93,8 +108,9 @@ namespace MQTTnet.AspNetCore | |||
{ | |||
if (!buffer.IsEmpty) | |||
{ | |||
if (PacketFormatterAdapter.TryDecode(_reader, buffer, out var packet, out consumed, out observed)) | |||
if (PacketFormatterAdapter.TryDecode(_reader, buffer, out var packet, out consumed, out observed, out var received)) | |||
{ | |||
BytesReceived += received; | |||
return packet; | |||
} | |||
else | |||
@@ -138,6 +154,7 @@ namespace MQTTnet.AspNetCore | |||
var msg = buffer.AsMemory(); | |||
var output = _output; | |||
msg.CopyTo(output.GetMemory(msg.Length)); | |||
BytesSent += msg.Length; | |||
PacketFormatterAdapter.FreeBuffer(); | |||
output.Advance(msg.Length); | |||
await output.FlushAsync().ConfigureAwait(false); | |||
@@ -9,13 +9,20 @@ namespace MQTTnet.AspNetCore | |||
{ | |||
public static class ReaderExtensions | |||
{ | |||
public static bool TryDecode(this MqttPacketFormatterAdapter formatter, SpanBasedMqttPacketBodyReader reader, in ReadOnlySequence<byte> input, out MqttBasePacket packet, out SequencePosition consumed, out SequencePosition observed) | |||
public static bool TryDecode(this MqttPacketFormatterAdapter formatter, | |||
SpanBasedMqttPacketBodyReader reader, | |||
in ReadOnlySequence<byte> input, | |||
out MqttBasePacket packet, | |||
out SequencePosition consumed, | |||
out SequencePosition observed, | |||
out int bytesRead) | |||
{ | |||
if (formatter == null) throw new ArgumentNullException(nameof(formatter)); | |||
packet = null; | |||
consumed = input.Start; | |||
observed = input.End; | |||
bytesRead = 0; | |||
var copy = input; | |||
if (copy.Length < 2) | |||
@@ -24,7 +31,7 @@ namespace MQTTnet.AspNetCore | |||
} | |||
var fixedheader = copy.First.Span[0]; | |||
if (!TryReadBodyLength(ref copy, out var bodyLength)) | |||
if (!TryReadBodyLength(ref copy, out int headerLength, out var bodyLength)) | |||
{ | |||
return false; | |||
} | |||
@@ -48,6 +55,7 @@ namespace MQTTnet.AspNetCore | |||
packet = formatter.Decode(receivedMqttPacket); | |||
consumed = bodySlice.End; | |||
observed = bodySlice.End; | |||
bytesRead = headerLength + bodyLength; | |||
return true; | |||
} | |||
@@ -62,15 +70,16 @@ namespace MQTTnet.AspNetCore | |||
return input.ToArray(); | |||
} | |||
private static bool TryReadBodyLength(ref ReadOnlySequence<byte> input, out int result) | |||
private static bool TryReadBodyLength(ref ReadOnlySequence<byte> input, out int headerLength, out int bodyLength) | |||
{ | |||
// Alorithm taken from https://docs.oasis-open.org/mqtt/mqtt/v3.1.1/errata01/os/mqtt-v3.1.1-errata01-os-complete.html. | |||
var multiplier = 1; | |||
var value = 0; | |||
byte encodedByte; | |||
var index = 1; | |||
result = 0; | |||
headerLength = 0; | |||
bodyLength = 0; | |||
var temp = input.Slice(0, Math.Min(5, input.Length)).GetMemory(); | |||
do | |||
@@ -93,7 +102,8 @@ namespace MQTTnet.AspNetCore | |||
input = input.Slice(index); | |||
result = value; | |||
headerLength = index; | |||
bodyLength = value; | |||
return true; | |||
} | |||
} | |||
@@ -23,24 +23,25 @@ namespace MQTTnet.AspNetCore.Tests | |||
var consumed = part.Start; | |||
var observed = part.Start; | |||
var result = false; | |||
var read = 0; | |||
var reader = new SpanBasedMqttPacketBodyReader(); | |||
part = sequence.Slice(sequence.Start, 0); // empty message should fail | |||
result = serializer.TryDecode(reader, part, out packet, out consumed, out observed); | |||
result = serializer.TryDecode(reader, part, out packet, out consumed, out observed, out read); | |||
Assert.IsFalse(result); | |||
part = sequence.Slice(sequence.Start, 1); // partial fixed header should fail | |||
result = serializer.TryDecode(reader, part, out packet, out consumed, out observed); | |||
result = serializer.TryDecode(reader, part, out packet, out consumed, out observed, out read); | |||
Assert.IsFalse(result); | |||
part = sequence.Slice(sequence.Start, 4); // partial body should fail | |||
result = serializer.TryDecode(reader, part, out packet, out consumed, out observed); | |||
result = serializer.TryDecode(reader, part, out packet, out consumed, out observed, out read); | |||
Assert.IsFalse(result); | |||
part = sequence; // complete msg should work | |||
result = serializer.TryDecode(reader, part, out packet, out consumed, out observed); | |||
result = serializer.TryDecode(reader, part, out packet, out consumed, out observed, out read); | |||
Assert.IsTrue(result); | |||
} | |||
} | |||