@@ -6,7 +6,6 @@ using System.Security.Cryptography.X509Certificates; | |||
using System.Threading.Tasks; | |||
using MQTTnet.Core.Channel; | |||
using MQTTnet.Core.Client; | |||
using MQTTnet.Core.Exceptions; | |||
using System.IO; | |||
namespace MQTTnet.Implementations | |||
@@ -16,10 +15,6 @@ namespace MQTTnet.Implementations | |||
private Socket _socket; | |||
private SslStream _sslStream; | |||
public Stream RawStream { get; private set; } | |||
public Stream SendStream { get; private set; } | |||
public Stream ReceiveStream { get; private set; } | |||
/// <summary> | |||
/// called on client sockets are created in connect | |||
/// </summary> | |||
@@ -36,61 +31,61 @@ namespace MQTTnet.Implementations | |||
{ | |||
_socket = socket ?? throw new ArgumentNullException(nameof(socket)); | |||
_sslStream = sslStream; | |||
CreateCommStreams(socket, sslStream); | |||
CreateStreams(socket, sslStream); | |||
} | |||
public Stream RawStream { get; private set; } | |||
public Stream SendStream { get; private set; } | |||
public Stream ReceiveStream { get; private set; } | |||
public async Task ConnectAsync(MqttClientOptions options) | |||
{ | |||
if (options == null) throw new ArgumentNullException(nameof(options)); | |||
try | |||
{ | |||
if (_socket == null) | |||
{ | |||
_socket = new Socket(SocketType.Stream, ProtocolType.Tcp); | |||
} | |||
await Task.Factory.FromAsync(_socket.BeginConnect, _socket.EndConnect, options.Server, options.GetPort(), null).ConfigureAwait(false); | |||
if (options.TlsOptions.UseTls) | |||
{ | |||
_sslStream = new SslStream(new NetworkStream(_socket, true)); | |||
if (_socket == null) | |||
{ | |||
_socket = new Socket(SocketType.Stream, ProtocolType.Tcp); | |||
} | |||
await _sslStream.AuthenticateAsClientAsync(options.Server, LoadCertificates(options), SslProtocols.Tls12, options.TlsOptions.CheckCertificateRevocation).ConfigureAwait(false); | |||
} | |||
await Task.Factory.FromAsync(_socket.BeginConnect, _socket.EndConnect, options.Server, options.GetPort(), null).ConfigureAwait(false); | |||
CreateCommStreams(_socket, _sslStream); | |||
} | |||
catch (SocketException exception) | |||
if (options.TlsOptions.UseTls) | |||
{ | |||
throw new MqttCommunicationException(exception); | |||
_sslStream = new SslStream(new NetworkStream(_socket, true)); | |||
await _sslStream.AuthenticateAsClientAsync(options.Server, LoadCertificates(options), SslProtocols.Tls12, options.TlsOptions.CheckCertificateRevocation).ConfigureAwait(false); | |||
} | |||
CreateStreams(_socket, _sslStream); | |||
} | |||
public Task DisconnectAsync() | |||
{ | |||
try | |||
{ | |||
Dispose(); | |||
return Task.FromResult(0); | |||
} | |||
catch (SocketException exception) | |||
{ | |||
throw new MqttCommunicationException(exception); | |||
} | |||
Dispose(); | |||
return Task.FromResult(0); | |||
} | |||
public void Dispose() | |||
{ | |||
_socket?.Dispose(); | |||
_sslStream?.Dispose(); | |||
RawStream?.Dispose(); | |||
RawStream = null; | |||
ReceiveStream?.Dispose(); | |||
ReceiveStream = null; | |||
SendStream?.Dispose(); | |||
SendStream = null; | |||
_socket?.Dispose(); | |||
_socket = null; | |||
_sslStream?.Dispose(); | |||
_sslStream = null; | |||
} | |||
private void CreateCommStreams(Socket socket, SslStream sslStream) | |||
private void CreateStreams(Socket socket, Stream sslStream) | |||
{ | |||
RawStream = (Stream)sslStream ?? new NetworkStream(socket); | |||
RawStream = sslStream ?? new NetworkStream(socket); | |||
//cannot use this as default buffering prevents from receiving the first connect message | |||
//need two streams otherwise read and write have to be synchronized | |||
@@ -14,7 +14,6 @@ namespace MQTTnet.Implementations | |||
private ClientWebSocket _webSocket = new ClientWebSocket(); | |||
public Stream RawStream { get; private set; } | |||
public Stream SendStream => RawStream; | |||
public Stream ReceiveStream => RawStream; | |||
@@ -56,15 +56,12 @@ namespace MQTTnet.Implementations | |||
public override bool CanSeek => false; | |||
public override bool CanWrite => true; | |||
public override long Length | |||
{ | |||
get { throw new NotSupportedException(); } | |||
} | |||
public override long Length => throw new NotSupportedException(); | |||
public override long Position | |||
{ | |||
get { throw new NotSupportedException(); } | |||
set { throw new NotSupportedException(); } | |||
get => throw new NotSupportedException(); | |||
set => throw new NotSupportedException(); | |||
} | |||
public override long Seek(long offset, SeekOrigin origin) | |||
@@ -6,7 +6,6 @@ using System.Security.Cryptography.X509Certificates; | |||
using System.Threading.Tasks; | |||
using MQTTnet.Core.Channel; | |||
using MQTTnet.Core.Client; | |||
using MQTTnet.Core.Exceptions; | |||
using System.IO; | |||
namespace MQTTnet.Implementations | |||
@@ -16,10 +15,6 @@ namespace MQTTnet.Implementations | |||
private Socket _socket; | |||
private SslStream _sslStream; | |||
public Stream ReceiveStream { get; private set; } | |||
public Stream RawStream => ReceiveStream; | |||
public Stream SendStream => ReceiveStream; | |||
/// <summary> | |||
/// called on client sockets are created in connect | |||
/// </summary> | |||
@@ -38,55 +33,45 @@ namespace MQTTnet.Implementations | |||
ReceiveStream = (Stream)sslStream ?? new NetworkStream(socket); | |||
} | |||
public Stream ReceiveStream { get; private set; } | |||
public Stream RawStream => ReceiveStream; | |||
public Stream SendStream => ReceiveStream; | |||
public async Task ConnectAsync(MqttClientOptions options) | |||
{ | |||
if (options == null) throw new ArgumentNullException(nameof(options)); | |||
try | |||
if (_socket == null) | |||
{ | |||
if (_socket == null) | |||
{ | |||
_socket = new Socket(SocketType.Stream, ProtocolType.Tcp); | |||
} | |||
_socket = new Socket(SocketType.Stream, ProtocolType.Tcp); | |||
} | |||
await _socket.ConnectAsync(options.Server, options.GetPort()).ConfigureAwait(false); | |||
await _socket.ConnectAsync(options.Server, options.GetPort()).ConfigureAwait(false); | |||
if (options.TlsOptions.UseTls) | |||
{ | |||
_sslStream = new SslStream(new NetworkStream(_socket, true)); | |||
ReceiveStream = _sslStream; | |||
await _sslStream.AuthenticateAsClientAsync(options.Server, LoadCertificates(options), SslProtocols.Tls12, options.TlsOptions.CheckCertificateRevocation).ConfigureAwait(false); | |||
} | |||
else | |||
{ | |||
ReceiveStream = new NetworkStream(_socket); | |||
} | |||
if (options.TlsOptions.UseTls) | |||
{ | |||
_sslStream = new SslStream(new NetworkStream(_socket, true)); | |||
ReceiveStream = _sslStream; | |||
await _sslStream.AuthenticateAsClientAsync(options.Server, LoadCertificates(options), SslProtocols.Tls12, options.TlsOptions.CheckCertificateRevocation).ConfigureAwait(false); | |||
} | |||
catch (SocketException exception) | |||
else | |||
{ | |||
throw new MqttCommunicationException(exception); | |||
ReceiveStream = new NetworkStream(_socket); | |||
} | |||
} | |||
public Task DisconnectAsync() | |||
{ | |||
try | |||
{ | |||
Dispose(); | |||
return Task.FromResult(0); | |||
} | |||
catch (SocketException exception) | |||
{ | |||
throw new MqttCommunicationException(exception); | |||
} | |||
Dispose(); | |||
return Task.FromResult(0); | |||
} | |||
public void Dispose() | |||
{ | |||
_socket?.Dispose(); | |||
_sslStream?.Dispose(); | |||
_socket = null; | |||
_sslStream?.Dispose(); | |||
_sslStream = null; | |||
} | |||
@@ -1,15 +1,13 @@ | |||
using System; | |||
using System.IO; | |||
using System.Linq; | |||
using System.Net.Sockets; | |||
using System.Runtime.InteropServices.WindowsRuntime; | |||
using System.Threading.Tasks; | |||
using Windows.Networking; | |||
using Windows.Networking.Sockets; | |||
using Windows.Security.Cryptography.Certificates; | |||
using Windows.Storage.Streams; | |||
using MQTTnet.Core.Channel; | |||
using MQTTnet.Core.Client; | |||
using MQTTnet.Core.Exceptions; | |||
namespace MQTTnet.Implementations | |||
{ | |||
@@ -26,89 +24,58 @@ namespace MQTTnet.Implementations | |||
_socket = socket ?? throw new ArgumentNullException(nameof(socket)); | |||
} | |||
public Stream SendStream { get; private set; } | |||
public Stream ReceiveStream { get; private set; } | |||
public Stream RawStream { get; private set; } | |||
public async Task ConnectAsync(MqttClientOptions options) | |||
{ | |||
if (options == null) throw new ArgumentNullException(nameof(options)); | |||
try | |||
{ | |||
if (_socket == null) | |||
{ | |||
_socket = new StreamSocket(); | |||
} | |||
if (!options.TlsOptions.UseTls) | |||
{ | |||
await _socket.ConnectAsync(new HostName(options.Server), options.GetPort().ToString()); | |||
} | |||
else | |||
{ | |||
_socket.Control.ClientCertificate = LoadCertificate(options); | |||
if (!options.TlsOptions.CheckCertificateRevocation) | |||
{ | |||
_socket.Control.IgnorableServerCertificateErrors.Add(ChainValidationResult.IncompleteChain); | |||
_socket.Control.IgnorableServerCertificateErrors.Add(ChainValidationResult.RevocationInformationMissing); | |||
} | |||
await _socket.ConnectAsync(new HostName(options.Server), options.GetPort().ToString(), SocketProtectionLevel.Tls12); | |||
} | |||
} | |||
catch (SocketException exception) | |||
if (_socket == null) | |||
{ | |||
throw new MqttCommunicationException(exception); | |||
_socket = new StreamSocket(); | |||
} | |||
} | |||
public Task DisconnectAsync() | |||
{ | |||
try | |||
if (!options.TlsOptions.UseTls) | |||
{ | |||
Dispose(); | |||
return Task.FromResult(0); | |||
await _socket.ConnectAsync(new HostName(options.Server), options.GetPort().ToString()); | |||
} | |||
catch (SocketException exception) | |||
else | |||
{ | |||
throw new MqttCommunicationException(exception); | |||
} | |||
} | |||
_socket.Control.ClientCertificate = LoadCertificate(options); | |||
public async Task WriteAsync(byte[] buffer) | |||
{ | |||
if (buffer == null) throw new ArgumentNullException(nameof(buffer)); | |||
if (!options.TlsOptions.CheckCertificateRevocation) | |||
{ | |||
_socket.Control.IgnorableServerCertificateErrors.Add(ChainValidationResult.IncompleteChain); | |||
_socket.Control.IgnorableServerCertificateErrors.Add(ChainValidationResult.RevocationInformationMissing); | |||
} | |||
try | |||
{ | |||
await _socket.OutputStream.WriteAsync(buffer.AsBuffer()); | |||
await _socket.OutputStream.FlushAsync(); | |||
} | |||
catch (SocketException exception) | |||
{ | |||
throw new MqttCommunicationException(exception); | |||
await _socket.ConnectAsync(new HostName(options.Server), options.GetPort().ToString(), SocketProtectionLevel.Tls12); | |||
} | |||
} | |||
public int Peek() | |||
{ | |||
return 0; | |||
ReceiveStream = _socket.InputStream.AsStreamForRead(); | |||
SendStream = _socket.OutputStream.AsStreamForWrite(); | |||
RawStream = ReceiveStream; | |||
} | |||
public async Task<ArraySegment<byte>> ReadAsync(int length, byte[] buffer) | |||
public Task DisconnectAsync() | |||
{ | |||
if (buffer == null) throw new ArgumentNullException(nameof(buffer)); | |||
try | |||
{ | |||
var result = await _socket.InputStream.ReadAsync(buffer.AsBuffer(), (uint)buffer.Length, InputStreamOptions.None); | |||
return new ArraySegment<byte>(buffer, 0, (int)result.Length); | |||
} | |||
catch (SocketException exception) | |||
{ | |||
throw new MqttCommunicationException(exception); | |||
} | |||
Dispose(); | |||
return Task.FromResult(0); | |||
} | |||
public void Dispose() | |||
{ | |||
RawStream?.Dispose(); | |||
RawStream = null; | |||
SendStream?.Dispose(); | |||
SendStream = null; | |||
ReceiveStream?.Dispose(); | |||
ReceiveStream = null; | |||
_socket?.Dispose(); | |||
_socket = null; | |||
} | |||
@@ -14,7 +14,6 @@ namespace MQTTnet.Implementations | |||
private ClientWebSocket _webSocket = new ClientWebSocket(); | |||
public Stream RawStream { get; private set; } | |||
public Stream SendStream => RawStream; | |||
public Stream ReceiveStream => RawStream; | |||
@@ -130,6 +130,11 @@ | |||
<Version>5.3.3</Version> | |||
</PackageReference> | |||
</ItemGroup> | |||
<ItemGroup> | |||
<Reference Include="System.Net.Security"> | |||
<HintPath>..\..\..\..\Program Files\dotnet\sdk\NuGetFallbackFolder\microsoft.netcore.app\2.0.0\ref\netcoreapp2.0\System.Net.Security.dll</HintPath> | |||
</Reference> | |||
</ItemGroup> | |||
<PropertyGroup Condition=" '$(VisualStudioVersion)' == '' or '$(VisualStudioVersion)' < '14.0' "> | |||
<VisualStudioVersion>14.0</VisualStudioVersion> | |||
</PropertyGroup> | |||
@@ -15,7 +15,6 @@ namespace MQTTnet.Core.Adapter | |||
public class MqttChannelCommunicationAdapter : IMqttCommunicationAdapter | |||
{ | |||
private readonly IMqttCommunicationChannel _channel; | |||
private readonly byte[] _readBuffer = new byte[BufferConstants.Size]; | |||
private Task _sendTask = Task.FromResult(0); // this task is used to prevent overlapping write | |||
@@ -29,76 +28,105 @@ namespace MQTTnet.Core.Adapter | |||
public Task ConnectAsync(MqttClientOptions options, TimeSpan timeout) | |||
{ | |||
return _channel.ConnectAsync(options).TimeoutAfter(timeout); | |||
try | |||
{ | |||
return _channel.ConnectAsync(options).TimeoutAfter(timeout); | |||
} | |||
catch (Exception exception) | |||
{ | |||
throw new MqttCommunicationException(exception); | |||
} | |||
} | |||
public Task DisconnectAsync() | |||
{ | |||
return _channel.DisconnectAsync(); | |||
try | |||
{ | |||
return _channel.DisconnectAsync(); | |||
} | |||
catch (Exception exception) | |||
{ | |||
throw new MqttCommunicationException(exception); | |||
} | |||
} | |||
public async Task SendPacketsAsync(TimeSpan timeout, IEnumerable<MqttBasePacket> packets) | |||
{ | |||
lock (_channel) | |||
try | |||
{ | |||
foreach (var packet in packets) | |||
lock (_channel) | |||
{ | |||
MqttTrace.Information(nameof(MqttChannelCommunicationAdapter), "TX >>> {0} [Timeout={1}]", packet, timeout); | |||
foreach (var packet in packets) | |||
{ | |||
MqttTrace.Information(nameof(MqttChannelCommunicationAdapter), "TX >>> {0} [Timeout={1}]", packet, timeout); | |||
var writeBuffer = PacketSerializer.Serialize(packet); | |||
_sendTask = _sendTask.ContinueWith(p => _channel.SendStream.WriteAsync(writeBuffer, 0, writeBuffer.Length)); | |||
var writeBuffer = PacketSerializer.Serialize(packet); | |||
_sendTask = _sendTask.ContinueWith(p => _channel.SendStream.WriteAsync(writeBuffer, 0, writeBuffer.Length)); | |||
} | |||
} | |||
} | |||
await _sendTask; // configure await false geneates stackoverflow | |||
await _channel.SendStream.FlushAsync().TimeoutAfter(timeout).ConfigureAwait(false); | |||
await _sendTask; // configure await false geneates stackoverflow | |||
await _channel.SendStream.FlushAsync().TimeoutAfter(timeout).ConfigureAwait(false); | |||
} | |||
catch (Exception exception) | |||
{ | |||
throw new MqttCommunicationException(exception); | |||
} | |||
} | |||
public async Task<MqttBasePacket> ReceivePacketAsync(TimeSpan timeout) | |||
{ | |||
Tuple<MqttPacketHeader, MemoryStream> tuple; | |||
if (timeout > TimeSpan.Zero) | |||
{ | |||
tuple = await ReceiveAsync(_channel.RawStream).TimeoutAfter(timeout).ConfigureAwait(false); | |||
} | |||
else | |||
try | |||
{ | |||
tuple = await ReceiveAsync(_channel.ReceiveStream).ConfigureAwait(false); | |||
} | |||
ReceivedMqttPacket receivedMqttPacket; | |||
if (timeout > TimeSpan.Zero) | |||
{ | |||
receivedMqttPacket = await ReceiveAsync(_channel.RawStream).TimeoutAfter(timeout).ConfigureAwait(false); | |||
} | |||
else | |||
{ | |||
receivedMqttPacket = await ReceiveAsync(_channel.ReceiveStream).ConfigureAwait(false); | |||
} | |||
var packet = PacketSerializer.Deserialize(tuple.Item1, tuple.Item2); | |||
var packet = PacketSerializer.Deserialize(receivedMqttPacket); | |||
if (packet == null) | |||
{ | |||
throw new MqttProtocolViolationException("Received malformed packet."); | |||
} | |||
if (packet == null) | |||
MqttTrace.Information(nameof(MqttChannelCommunicationAdapter), "RX <<< {0}", packet); | |||
return packet; | |||
} | |||
catch (Exception exception) | |||
{ | |||
throw new MqttProtocolViolationException("Received malformed packet."); | |||
throw new MqttCommunicationException(exception); | |||
} | |||
MqttTrace.Information(nameof(MqttChannelCommunicationAdapter), "RX <<< {0}", packet); | |||
return packet; | |||
} | |||
private async Task<Tuple<MqttPacketHeader, MemoryStream>> ReceiveAsync(Stream stream) | |||
private async Task<ReceivedMqttPacket> ReceiveAsync(Stream stream) | |||
{ | |||
var header = MqttPacketReader.ReadHeaderFromSource(stream); | |||
MemoryStream body; | |||
if (header.BodyLength > 0) | |||
if (header.BodyLength == 0) | |||
{ | |||
var totalRead = 0; | |||
do | |||
{ | |||
var read = await stream.ReadAsync(_readBuffer, totalRead, header.BodyLength - totalRead).ConfigureAwait(false); | |||
totalRead += read; | |||
} while (totalRead < header.BodyLength); | |||
body = new MemoryStream(_readBuffer, 0, header.BodyLength); | |||
return new ReceivedMqttPacket(header, new MemoryStream(0)); | |||
} | |||
else | |||
var body = new byte[header.BodyLength]; | |||
var offset = 0; | |||
do | |||
{ | |||
var readBytesCount = await stream.ReadAsync(body, offset, body.Length - offset).ConfigureAwait(false); | |||
offset += readBytesCount; | |||
} while (offset < header.BodyLength); | |||
if (offset > header.BodyLength) | |||
{ | |||
body = new MemoryStream(); | |||
throw new MqttCommunicationException($"Read more body bytes than required ({offset}/{header.BodyLength})."); | |||
} | |||
return Tuple.Create(header, body); | |||
return new ReceivedMqttPacket(header, new MemoryStream(body, 0, body.Length)); | |||
} | |||
} | |||
} |
@@ -0,0 +1,19 @@ | |||
using System; | |||
using System.IO; | |||
using MQTTnet.Core.Packets; | |||
namespace MQTTnet.Core.Adapter | |||
{ | |||
public class ReceivedMqttPacket | |||
{ | |||
public ReceivedMqttPacket(MqttPacketHeader header, MemoryStream body) | |||
{ | |||
Header = header ?? throw new ArgumentNullException(nameof(header)); | |||
Body = body ?? throw new ArgumentNullException(nameof(body)); | |||
} | |||
public MqttPacketHeader Header { get; } | |||
public MemoryStream Body { get; } | |||
} | |||
} |
@@ -99,6 +99,11 @@ namespace MQTTnet.Core.Client | |||
public async Task DisconnectAsync() | |||
{ | |||
if (!IsConnected) | |||
{ | |||
return; | |||
} | |||
try | |||
{ | |||
await SendAsync(new MqttDisconnectPacket()).ConfigureAwait(false); | |||
@@ -11,7 +11,7 @@ namespace MQTTnet.Core.Client | |||
public class MqttPacketDispatcher | |||
{ | |||
private readonly ConcurrentDictionary<Type, TaskCompletionSource<MqttBasePacket>> _packetByResponseType = new ConcurrentDictionary<Type, TaskCompletionSource<MqttBasePacket>>(); | |||
private readonly ConcurrentDictionary<Type, ConcurrentDictionary<ushort,TaskCompletionSource<MqttBasePacket>>> _packetByResponseTypeAndIdentifier = new ConcurrentDictionary<Type, ConcurrentDictionary<ushort,TaskCompletionSource<MqttBasePacket>>>(); | |||
private readonly ConcurrentDictionary<Type, ConcurrentDictionary<ushort, TaskCompletionSource<MqttBasePacket>>> _packetByResponseTypeAndIdentifier = new ConcurrentDictionary<Type, ConcurrentDictionary<ushort, TaskCompletionSource<MqttBasePacket>>>(); | |||
public async Task<MqttBasePacket> WaitForPacketAsync(MqttBasePacket request, Type responseType, TimeSpan timeout) | |||
{ | |||
@@ -24,7 +24,7 @@ namespace MQTTnet.Core.Client | |||
} | |||
catch (MqttCommunicationTimedOutException) | |||
{ | |||
MqttTrace.Warning(nameof(MqttPacketDispatcher), "Timeout while waiting for packet."); | |||
MqttTrace.Warning(nameof(MqttPacketDispatcher), "Timeout while waiting for packet of type '{0}'.", responseType.Name); | |||
throw; | |||
} | |||
finally | |||
@@ -42,16 +42,20 @@ namespace MQTTnet.Core.Client | |||
{ | |||
if (_packetByResponseTypeAndIdentifier.TryGetValue(type, out var byid)) | |||
{ | |||
if (byid.TryRemove( withIdentifier.PacketIdentifier, out var tcs)) | |||
if (byid.TryRemove(withIdentifier.PacketIdentifier, out var tcs)) | |||
{ | |||
tcs.TrySetResult( packet ); | |||
tcs.TrySetResult(packet); | |||
return; | |||
} | |||
} | |||
} | |||
else if (_packetByResponseType.TryRemove(type, out var tcs)) | |||
{ | |||
tcs.TrySetResult(packet); | |||
return; | |||
} | |||
throw new InvalidOperationException($"Packet of type '{type.Name}' not handled or dispatched."); | |||
} | |||
public void Reset() | |||
@@ -25,20 +25,22 @@ namespace MQTTnet.Core.Internal | |||
try | |||
{ | |||
cancellationTokenSource.CancelAfter(timeout); | |||
#pragma warning disable 4014 | |||
task.ContinueWith(t => | |||
{ | |||
if (t.IsFaulted) | |||
{ | |||
tcs.TrySetException(t.Exception); | |||
} | |||
#pragma warning restore 4014 | |||
{ | |||
if (t.IsFaulted) | |||
{ | |||
tcs.TrySetException(t.Exception); | |||
} | |||
if (t.IsCompleted) | |||
{ | |||
tcs.TrySetResult(t.Result); | |||
} | |||
}, cancellationTokenSource.Token); | |||
if (t.IsCompleted) | |||
{ | |||
tcs.TrySetResult(t.Result); | |||
} | |||
}, cancellationTokenSource.Token); | |||
cancellationTokenSource.CancelAfter(timeout); | |||
return await tcs.Task; | |||
} | |||
catch (TaskCanceledException) | |||
@@ -1,4 +1,4 @@ | |||
using System.IO; | |||
using MQTTnet.Core.Adapter; | |||
using MQTTnet.Core.Packets; | |||
namespace MQTTnet.Core.Serializer | |||
@@ -9,6 +9,6 @@ namespace MQTTnet.Core.Serializer | |||
byte[] Serialize(MqttBasePacket mqttPacket); | |||
MqttBasePacket Deserialize(MqttPacketHeader header, MemoryStream stream); | |||
MqttBasePacket Deserialize(ReceivedMqttPacket receivedMqttPacket); | |||
} | |||
} |
@@ -1,6 +1,7 @@ | |||
using System; | |||
using System.IO; | |||
using System.Text; | |||
using MQTTnet.Core.Adapter; | |||
using MQTTnet.Core.Exceptions; | |||
using MQTTnet.Core.Protocol; | |||
using MQTTnet.Core.Packets; | |||
@@ -9,15 +10,29 @@ namespace MQTTnet.Core.Serializer | |||
{ | |||
public sealed class MqttPacketReader : BinaryReader | |||
{ | |||
private readonly MqttPacketHeader _header; | |||
public MqttPacketReader(Stream stream, MqttPacketHeader header) | |||
: base(stream, Encoding.UTF8, true) | |||
private readonly ReceivedMqttPacket _receivedMqttPacket; | |||
public MqttPacketReader(ReceivedMqttPacket receivedMqttPacket) | |||
: base(receivedMqttPacket.Body, Encoding.UTF8, true) | |||
{ | |||
_header = header; | |||
_receivedMqttPacket = receivedMqttPacket; | |||
} | |||
public bool EndOfRemainingData => BaseStream.Position == _header.BodyLength; | |||
public bool EndOfRemainingData => BaseStream.Position == _receivedMqttPacket.Header.BodyLength; | |||
public static MqttPacketHeader ReadHeaderFromSource(Stream stream) | |||
{ | |||
var fixedHeader = (byte)stream.ReadByte(); | |||
var controlPacketType = (MqttControlPacketType)(fixedHeader >> 4); | |||
var bodyLength = ReadBodyLengthFromSource(stream); | |||
return new MqttPacketHeader | |||
{ | |||
FixedHeader = fixedHeader, | |||
ControlPacketType = controlPacketType, | |||
BodyLength = bodyLength | |||
}; | |||
} | |||
public override ushort ReadUInt16() | |||
{ | |||
@@ -44,21 +59,7 @@ namespace MQTTnet.Core.Serializer | |||
public byte[] ReadRemainingData() | |||
{ | |||
return ReadBytes(_header.BodyLength - (int)BaseStream.Position); | |||
} | |||
public static MqttPacketHeader ReadHeaderFromSource(Stream stream) | |||
{ | |||
var fixedHeader = (byte)stream.ReadByte(); | |||
var controlPacketType = (MqttControlPacketType)(fixedHeader >> 4); | |||
var bodyLength = ReadBodyLengthFromSource(stream); | |||
return new MqttPacketHeader | |||
{ | |||
FixedHeader = fixedHeader, | |||
ControlPacketType = controlPacketType, | |||
BodyLength = bodyLength | |||
}; | |||
return ReadBytes(_receivedMqttPacket.Header.BodyLength - (int)BaseStream.Position); | |||
} | |||
private static int ReadBodyLengthFromSource(Stream stream) | |||
@@ -74,7 +75,7 @@ namespace MQTTnet.Core.Serializer | |||
multiplier *= 128; | |||
if (multiplier > 128 * 128 * 128) | |||
{ | |||
throw new MqttProtocolViolationException("Remaining length is ivalid."); | |||
throw new MqttProtocolViolationException("Remaining length is invalid."); | |||
} | |||
} while ((encodedByte & 128) != 0); | |||
return value; | |||
@@ -3,6 +3,7 @@ using System.Collections.Generic; | |||
using System.IO; | |||
using System.Linq; | |||
using System.Text; | |||
using MQTTnet.Core.Adapter; | |||
using MQTTnet.Core.Exceptions; | |||
using MQTTnet.Core.Packets; | |||
using MQTTnet.Core.Protocol; | |||
@@ -110,14 +111,13 @@ namespace MQTTnet.Core.Serializer | |||
throw new MqttProtocolViolationException("Packet type invalid."); | |||
} | |||
public MqttBasePacket Deserialize(MqttPacketHeader header, MemoryStream body) | |||
public MqttBasePacket Deserialize(ReceivedMqttPacket receivedMqttPacket) | |||
{ | |||
if (header == null) throw new ArgumentNullException(nameof(header)); | |||
if (body == null) throw new ArgumentNullException(nameof(body)); | |||
if (receivedMqttPacket == null) throw new ArgumentNullException(nameof(receivedMqttPacket)); | |||
using (var reader = new MqttPacketReader(body, header)) | |||
using (var reader = new MqttPacketReader(receivedMqttPacket)) | |||
{ | |||
return Deserialize(header, reader); | |||
return Deserialize(receivedMqttPacket.Header, reader); | |||
} | |||
} | |||
@@ -3,6 +3,7 @@ using System.IO; | |||
using System.Text; | |||
using System.Threading.Tasks; | |||
using Microsoft.VisualStudio.TestTools.UnitTesting; | |||
using MQTTnet.Core.Adapter; | |||
using MQTTnet.Core.Channel; | |||
using MQTTnet.Core.Client; | |||
using MQTTnet.Core.Packets; | |||
@@ -436,20 +437,20 @@ namespace MQTTnet.Core.Tests | |||
private static void DeserializeAndCompare(MqttBasePacket packet, string expectedBase64Value) | |||
{ | |||
var serializer = new MqttPacketSerializer(); | |||
var buffer1 = serializer.Serialize(packet); | |||
using (var headerStream = new MemoryStream( buffer1 )) | |||
using (var headerStream = new MemoryStream(buffer1)) | |||
{ | |||
var header = MqttPacketReader.ReadHeaderFromSource( headerStream ); | |||
var header = MqttPacketReader.ReadHeaderFromSource(headerStream); | |||
using (var bodyStream = new MemoryStream( buffer1, (int)headerStream.Position, header.BodyLength )) | |||
using (var bodyStream = new MemoryStream(buffer1, (int)headerStream.Position, header.BodyLength)) | |||
{ | |||
var deserializedPacket = serializer.Deserialize(header, bodyStream); | |||
var buffer2 = serializer.Serialize( deserializedPacket ); | |||
var deserializedPacket = serializer.Deserialize(new ReceivedMqttPacket(header, bodyStream)); | |||
var buffer2 = serializer.Serialize(deserializedPacket); | |||
Assert.AreEqual( expectedBase64Value, Convert.ToBase64String( buffer2 ) ); | |||
} | |||
Assert.AreEqual(expectedBase64Value, Convert.ToBase64String(buffer2)); | |||
} | |||
} | |||
} | |||
} | |||
@@ -17,18 +17,18 @@ namespace MQTTnet.TestApp.NetFramework | |||
{ | |||
public static async Task RunAsync() | |||
{ | |||
var server = Task.Run(() => RunServerAsync()); | |||
var client = Task.Run(() => RunClientAsync(2000, TimeSpan.FromMilliseconds(10))); | |||
var server = Task.Factory.StartNew(RunServerAsync, TaskCreationOptions.LongRunning); | |||
var client = Task.Factory.StartNew(() => RunClientAsync(2000, TimeSpan.FromMilliseconds(10)), TaskCreationOptions.LongRunning); | |||
await Task.WhenAll(server, client).ConfigureAwait(false); | |||
} | |||
private static Task RunClientsAsync(int msgChunkSize, TimeSpan interval) | |||
{ | |||
return Task.WhenAll(Enumerable.Range(0, 3).Select((i) => Task.Run(() => RunClientAsync(msgChunkSize, interval)))); | |||
return Task.WhenAll(Enumerable.Range(0, 3).Select(i => Task.Run(() => RunClientAsync(msgChunkSize, interval)))); | |||
} | |||
private static async Task RunClientAsync( int msgChunkSize, TimeSpan interval ) | |||
private static async Task RunClientAsync(int msgChunkSize, TimeSpan interval) | |||
{ | |||
try | |||
{ | |||
@@ -83,7 +83,7 @@ namespace MQTTnet.TestApp.NetFramework | |||
Console.WriteLine("### WAITING FOR APPLICATION MESSAGES ###"); | |||
var testMessageCount = 1000; | |||
var testMessageCount = 10000; | |||
var message = CreateMessage(); | |||
var stopwatch = Stopwatch.StartNew(); | |||
for (var i = 0; i < testMessageCount; i++) | |||
@@ -92,8 +92,8 @@ namespace MQTTnet.TestApp.NetFramework | |||
} | |||
stopwatch.Stop(); | |||
Console.WriteLine($"Sent 1000 messages within {stopwatch.ElapsedMilliseconds} ms ({stopwatch.ElapsedMilliseconds / (float)testMessageCount} ms / message)."); | |||
Console.WriteLine($"Sent 10.000 messages within {stopwatch.ElapsedMilliseconds} ms ({stopwatch.ElapsedMilliseconds / (float)testMessageCount} ms / message)."); | |||
stopwatch.Restart(); | |||
var sentMessagesCount = 0; | |||
while (stopwatch.ElapsedMilliseconds < 1000) | |||
@@ -109,32 +109,32 @@ namespace MQTTnet.TestApp.NetFramework | |||
while (true) | |||
{ | |||
var msgs = Enumerable.Range( 0, msgChunkSize ) | |||
.Select( i => CreateMessage() ) | |||
var msgs = Enumerable.Range(0, msgChunkSize) | |||
.Select(i => CreateMessage()) | |||
.ToList(); | |||
if (false) | |||
{ | |||
//send concurrent (test for raceconditions) | |||
var sendTasks = msgs | |||
.Select( msg => PublishSingleMessage( client, msg, ref msgCount ) ) | |||
.Select(msg => PublishSingleMessage(client, msg, ref msgCount)) | |||
.ToList(); | |||
await Task.WhenAll( sendTasks ); | |||
await Task.WhenAll(sendTasks); | |||
} | |||
else | |||
{ | |||
await client.PublishAsync( msgs ); | |||
await client.PublishAsync(msgs); | |||
msgCount += msgs.Count; | |||
//send multiple | |||
} | |||
var now = DateTime.Now; | |||
if (last < now - TimeSpan.FromSeconds(1)) | |||
{ | |||
Console.WriteLine( $"sending {msgCount} inteded {msgChunkSize / interval.TotalSeconds}" ); | |||
Console.WriteLine($"sending {msgCount} inteded {msgChunkSize / interval.TotalSeconds}"); | |||
msgCount = 0; | |||
last = now; | |||
} | |||
@@ -152,19 +152,19 @@ namespace MQTTnet.TestApp.NetFramework | |||
{ | |||
return new MqttApplicationMessage( | |||
"A/B/C", | |||
Encoding.UTF8.GetBytes( "Hello World" ), | |||
Encoding.UTF8.GetBytes("Hello World"), | |||
MqttQualityOfServiceLevel.AtMostOnce, | |||
false | |||
); | |||
} | |||
private static Task PublishSingleMessage( IMqttClient client, MqttApplicationMessage applicationMessage, ref int count ) | |||
private static Task PublishSingleMessage(IMqttClient client, MqttApplicationMessage applicationMessage, ref int count) | |||
{ | |||
Interlocked.Increment( ref count ); | |||
return Task.Run( () => | |||
{ | |||
return client.PublishAsync( applicationMessage ); | |||
} ); | |||
Interlocked.Increment(ref count); | |||
return Task.Run(() => | |||
{ | |||
return client.PublishAsync(applicationMessage); | |||
}); | |||
} | |||
private static void RunServerAsync() | |||
@@ -187,19 +187,18 @@ namespace MQTTnet.TestApp.NetFramework | |||
}, | |||
DefaultCommunicationTimeout = TimeSpan.FromMinutes(10) | |||
}; | |||
var mqttServer = new MqttServerFactory().CreateMqttServer(options); | |||
var last = DateTime.Now; | |||
var msgs = 0; | |||
mqttServer.ApplicationMessageReceived += (sender, args) => | |||
var stopwatch = Stopwatch.StartNew(); | |||
mqttServer.ApplicationMessageReceived += (sender, args) => | |||
{ | |||
msgs++; | |||
var now = DateTime.Now; | |||
if (last < now - TimeSpan.FromSeconds(1)) | |||
if (stopwatch.ElapsedMilliseconds > 1000) | |||
{ | |||
Console.WriteLine($"received {msgs}"); | |||
msgs = 0; | |||
last = now; | |||
stopwatch.Restart(); | |||
} | |||
}; | |||
mqttServer.Start(); | |||
@@ -42,7 +42,7 @@ | |||
<UseVSHostingProcess>false</UseVSHostingProcess> | |||
<ErrorReport>prompt</ErrorReport> | |||
<Prefer32Bit>true</Prefer32Bit> | |||
<UseDotNetNativeToolchain>true</UseDotNetNativeToolchain> | |||
<UseDotNetNativeToolchain>false</UseDotNetNativeToolchain> | |||
</PropertyGroup> | |||
<PropertyGroup Condition="'$(Configuration)|$(Platform)' == 'Debug|ARM'"> | |||
<DebugSymbols>true</DebugSymbols> | |||