Browse Source

Refactor latest changes

release/3.x.x
Christian Kratky 7 years ago
parent
commit
76105de4c7
18 changed files with 268 additions and 262 deletions
  1. +31
    -36
      Frameworks/MQTTnet.NetFramework/Implementations/MqttTcpChannel.cs
  2. +0
    -1
      Frameworks/MQTTnet.NetFramework/Implementations/MqttWebSocketChannel.cs
  3. +3
    -6
      Frameworks/MQTTnet.NetFramework/Implementations/WebSocketStream.cs
  4. +19
    -34
      Frameworks/MQTTnet.NetStandard/Implementations/MqttTcpChannel.cs
  5. +32
    -65
      Frameworks/MQTTnet.UniversalWindows/Implementations/MqttTcpChannel.cs
  6. +0
    -1
      Frameworks/MQTTnet.UniversalWindows/Implementations/MqttWebSocketChannel.cs
  7. +5
    -0
      Frameworks/MQTTnet.UniversalWindows/MQTTnet.UniversalWindows.csproj
  8. +67
    -39
      MQTTnet.Core/Adapter/MqttChannelCommunicationAdapter.cs
  9. +19
    -0
      MQTTnet.Core/Adapter/ReceivedMqttPacket.cs
  10. +5
    -0
      MQTTnet.Core/Client/MqttClient.cs
  11. +8
    -4
      MQTTnet.Core/Client/MqttPacketDispatcher.cs
  12. +13
    -11
      MQTTnet.Core/Internal/TaskExtensions.cs
  13. +2
    -2
      MQTTnet.Core/Serializer/IMqttPacketSerializer.cs
  14. +23
    -22
      MQTTnet.Core/Serializer/MqttPacketReader.cs
  15. +5
    -5
      MQTTnet.Core/Serializer/MqttPacketSerializer.cs
  16. +9
    -8
      Tests/MQTTnet.Core.Tests/MqttPacketSerializerTests.cs
  17. +26
    -27
      Tests/MQTTnet.TestApp.NetFramework/PerformanceTest.cs
  18. +1
    -1
      Tests/MQTTnet.TestApp.UniversalWindows/MQTTnet.TestApp.UniversalWindows.csproj

+ 31
- 36
Frameworks/MQTTnet.NetFramework/Implementations/MqttTcpChannel.cs View File

@@ -6,7 +6,6 @@ using System.Security.Cryptography.X509Certificates;
using System.Threading.Tasks; using System.Threading.Tasks;
using MQTTnet.Core.Channel; using MQTTnet.Core.Channel;
using MQTTnet.Core.Client; using MQTTnet.Core.Client;
using MQTTnet.Core.Exceptions;
using System.IO; using System.IO;


namespace MQTTnet.Implementations namespace MQTTnet.Implementations
@@ -16,10 +15,6 @@ namespace MQTTnet.Implementations
private Socket _socket; private Socket _socket;
private SslStream _sslStream; private SslStream _sslStream;


public Stream RawStream { get; private set; }
public Stream SendStream { get; private set; }
public Stream ReceiveStream { get; private set; }

/// <summary> /// <summary>
/// called on client sockets are created in connect /// called on client sockets are created in connect
/// </summary> /// </summary>
@@ -36,61 +31,61 @@ namespace MQTTnet.Implementations
{ {
_socket = socket ?? throw new ArgumentNullException(nameof(socket)); _socket = socket ?? throw new ArgumentNullException(nameof(socket));
_sslStream = sslStream; _sslStream = sslStream;
CreateCommStreams(socket, sslStream);
CreateStreams(socket, sslStream);
} }


public Stream RawStream { get; private set; }
public Stream SendStream { get; private set; }
public Stream ReceiveStream { get; private set; }

public async Task ConnectAsync(MqttClientOptions options) public async Task ConnectAsync(MqttClientOptions options)
{ {
if (options == null) throw new ArgumentNullException(nameof(options)); if (options == null) throw new ArgumentNullException(nameof(options));
try
{
if (_socket == null)
{
_socket = new Socket(SocketType.Stream, ProtocolType.Tcp);
}

await Task.Factory.FromAsync(_socket.BeginConnect, _socket.EndConnect, options.Server, options.GetPort(), null).ConfigureAwait(false);


if (options.TlsOptions.UseTls)
{
_sslStream = new SslStream(new NetworkStream(_socket, true));
if (_socket == null)
{
_socket = new Socket(SocketType.Stream, ProtocolType.Tcp);
}


await _sslStream.AuthenticateAsClientAsync(options.Server, LoadCertificates(options), SslProtocols.Tls12, options.TlsOptions.CheckCertificateRevocation).ConfigureAwait(false);
}
await Task.Factory.FromAsync(_socket.BeginConnect, _socket.EndConnect, options.Server, options.GetPort(), null).ConfigureAwait(false);


CreateCommStreams(_socket, _sslStream);
}
catch (SocketException exception)
if (options.TlsOptions.UseTls)
{ {
throw new MqttCommunicationException(exception);
_sslStream = new SslStream(new NetworkStream(_socket, true));

await _sslStream.AuthenticateAsClientAsync(options.Server, LoadCertificates(options), SslProtocols.Tls12, options.TlsOptions.CheckCertificateRevocation).ConfigureAwait(false);
} }

CreateStreams(_socket, _sslStream);
} }


public Task DisconnectAsync() public Task DisconnectAsync()
{ {
try
{
Dispose();
return Task.FromResult(0);
}
catch (SocketException exception)
{
throw new MqttCommunicationException(exception);
}
Dispose();
return Task.FromResult(0);
} }


public void Dispose() public void Dispose()
{ {
_socket?.Dispose();
_sslStream?.Dispose();
RawStream?.Dispose();
RawStream = null;

ReceiveStream?.Dispose();
ReceiveStream = null;

SendStream?.Dispose();
SendStream = null;


_socket?.Dispose();
_socket = null; _socket = null;

_sslStream?.Dispose();
_sslStream = null; _sslStream = null;
} }


private void CreateCommStreams(Socket socket, SslStream sslStream)
private void CreateStreams(Socket socket, Stream sslStream)
{ {
RawStream = (Stream)sslStream ?? new NetworkStream(socket);
RawStream = sslStream ?? new NetworkStream(socket);


//cannot use this as default buffering prevents from receiving the first connect message //cannot use this as default buffering prevents from receiving the first connect message
//need two streams otherwise read and write have to be synchronized //need two streams otherwise read and write have to be synchronized


+ 0
- 1
Frameworks/MQTTnet.NetFramework/Implementations/MqttWebSocketChannel.cs View File

@@ -14,7 +14,6 @@ namespace MQTTnet.Implementations
private ClientWebSocket _webSocket = new ClientWebSocket(); private ClientWebSocket _webSocket = new ClientWebSocket();


public Stream RawStream { get; private set; } public Stream RawStream { get; private set; }

public Stream SendStream => RawStream; public Stream SendStream => RawStream;
public Stream ReceiveStream => RawStream; public Stream ReceiveStream => RawStream;




+ 3
- 6
Frameworks/MQTTnet.NetFramework/Implementations/WebSocketStream.cs View File

@@ -56,15 +56,12 @@ namespace MQTTnet.Implementations
public override bool CanSeek => false; public override bool CanSeek => false;
public override bool CanWrite => true; public override bool CanWrite => true;


public override long Length
{
get { throw new NotSupportedException(); }
}
public override long Length => throw new NotSupportedException();


public override long Position public override long Position
{ {
get { throw new NotSupportedException(); }
set { throw new NotSupportedException(); }
get => throw new NotSupportedException();
set => throw new NotSupportedException();
} }


public override long Seek(long offset, SeekOrigin origin) public override long Seek(long offset, SeekOrigin origin)


+ 19
- 34
Frameworks/MQTTnet.NetStandard/Implementations/MqttTcpChannel.cs View File

@@ -6,7 +6,6 @@ using System.Security.Cryptography.X509Certificates;
using System.Threading.Tasks; using System.Threading.Tasks;
using MQTTnet.Core.Channel; using MQTTnet.Core.Channel;
using MQTTnet.Core.Client; using MQTTnet.Core.Client;
using MQTTnet.Core.Exceptions;
using System.IO; using System.IO;


namespace MQTTnet.Implementations namespace MQTTnet.Implementations
@@ -16,10 +15,6 @@ namespace MQTTnet.Implementations
private Socket _socket; private Socket _socket;
private SslStream _sslStream; private SslStream _sslStream;
public Stream ReceiveStream { get; private set; }
public Stream RawStream => ReceiveStream;
public Stream SendStream => ReceiveStream;

/// <summary> /// <summary>
/// called on client sockets are created in connect /// called on client sockets are created in connect
/// </summary> /// </summary>
@@ -38,55 +33,45 @@ namespace MQTTnet.Implementations
ReceiveStream = (Stream)sslStream ?? new NetworkStream(socket); ReceiveStream = (Stream)sslStream ?? new NetworkStream(socket);
} }


public Stream ReceiveStream { get; private set; }
public Stream RawStream => ReceiveStream;
public Stream SendStream => ReceiveStream;

public async Task ConnectAsync(MqttClientOptions options) public async Task ConnectAsync(MqttClientOptions options)
{ {
if (options == null) throw new ArgumentNullException(nameof(options)); if (options == null) throw new ArgumentNullException(nameof(options));


try
if (_socket == null)
{ {
if (_socket == null)
{
_socket = new Socket(SocketType.Stream, ProtocolType.Tcp);
}
_socket = new Socket(SocketType.Stream, ProtocolType.Tcp);
}


await _socket.ConnectAsync(options.Server, options.GetPort()).ConfigureAwait(false);
await _socket.ConnectAsync(options.Server, options.GetPort()).ConfigureAwait(false);


if (options.TlsOptions.UseTls)
{
_sslStream = new SslStream(new NetworkStream(_socket, true));
ReceiveStream = _sslStream;
await _sslStream.AuthenticateAsClientAsync(options.Server, LoadCertificates(options), SslProtocols.Tls12, options.TlsOptions.CheckCertificateRevocation).ConfigureAwait(false);
}
else
{
ReceiveStream = new NetworkStream(_socket);
}
if (options.TlsOptions.UseTls)
{
_sslStream = new SslStream(new NetworkStream(_socket, true));
ReceiveStream = _sslStream;
await _sslStream.AuthenticateAsClientAsync(options.Server, LoadCertificates(options), SslProtocols.Tls12, options.TlsOptions.CheckCertificateRevocation).ConfigureAwait(false);
} }
catch (SocketException exception)
else
{ {
throw new MqttCommunicationException(exception);
ReceiveStream = new NetworkStream(_socket);
} }
} }


public Task DisconnectAsync() public Task DisconnectAsync()
{ {
try
{
Dispose();
return Task.FromResult(0);
}
catch (SocketException exception)
{
throw new MqttCommunicationException(exception);
}
Dispose();
return Task.FromResult(0);
} }


public void Dispose() public void Dispose()
{ {
_socket?.Dispose(); _socket?.Dispose();
_sslStream?.Dispose();

_socket = null; _socket = null;

_sslStream?.Dispose();
_sslStream = null; _sslStream = null;
} }




+ 32
- 65
Frameworks/MQTTnet.UniversalWindows/Implementations/MqttTcpChannel.cs View File

@@ -1,15 +1,13 @@
using System; using System;
using System.IO;
using System.Linq; using System.Linq;
using System.Net.Sockets;
using System.Runtime.InteropServices.WindowsRuntime; using System.Runtime.InteropServices.WindowsRuntime;
using System.Threading.Tasks; using System.Threading.Tasks;
using Windows.Networking; using Windows.Networking;
using Windows.Networking.Sockets; using Windows.Networking.Sockets;
using Windows.Security.Cryptography.Certificates; using Windows.Security.Cryptography.Certificates;
using Windows.Storage.Streams;
using MQTTnet.Core.Channel; using MQTTnet.Core.Channel;
using MQTTnet.Core.Client; using MQTTnet.Core.Client;
using MQTTnet.Core.Exceptions;


namespace MQTTnet.Implementations namespace MQTTnet.Implementations
{ {
@@ -26,89 +24,58 @@ namespace MQTTnet.Implementations
_socket = socket ?? throw new ArgumentNullException(nameof(socket)); _socket = socket ?? throw new ArgumentNullException(nameof(socket));
} }


public Stream SendStream { get; private set; }
public Stream ReceiveStream { get; private set; }
public Stream RawStream { get; private set; }

public async Task ConnectAsync(MqttClientOptions options) public async Task ConnectAsync(MqttClientOptions options)
{ {
if (options == null) throw new ArgumentNullException(nameof(options)); if (options == null) throw new ArgumentNullException(nameof(options));
try
{
if (_socket == null)
{
_socket = new StreamSocket();
}

if (!options.TlsOptions.UseTls)
{
await _socket.ConnectAsync(new HostName(options.Server), options.GetPort().ToString());
}
else
{
_socket.Control.ClientCertificate = LoadCertificate(options);

if (!options.TlsOptions.CheckCertificateRevocation)
{
_socket.Control.IgnorableServerCertificateErrors.Add(ChainValidationResult.IncompleteChain);
_socket.Control.IgnorableServerCertificateErrors.Add(ChainValidationResult.RevocationInformationMissing);
}


await _socket.ConnectAsync(new HostName(options.Server), options.GetPort().ToString(), SocketProtectionLevel.Tls12);
}
}
catch (SocketException exception)
if (_socket == null)
{ {
throw new MqttCommunicationException(exception);
_socket = new StreamSocket();
} }
}


public Task DisconnectAsync()
{
try
if (!options.TlsOptions.UseTls)
{ {
Dispose();
return Task.FromResult(0);
await _socket.ConnectAsync(new HostName(options.Server), options.GetPort().ToString());
} }
catch (SocketException exception)
else
{ {
throw new MqttCommunicationException(exception);
}
}
_socket.Control.ClientCertificate = LoadCertificate(options);


public async Task WriteAsync(byte[] buffer)
{
if (buffer == null) throw new ArgumentNullException(nameof(buffer));
if (!options.TlsOptions.CheckCertificateRevocation)
{
_socket.Control.IgnorableServerCertificateErrors.Add(ChainValidationResult.IncompleteChain);
_socket.Control.IgnorableServerCertificateErrors.Add(ChainValidationResult.RevocationInformationMissing);
}


try
{
await _socket.OutputStream.WriteAsync(buffer.AsBuffer());
await _socket.OutputStream.FlushAsync();
}
catch (SocketException exception)
{
throw new MqttCommunicationException(exception);
await _socket.ConnectAsync(new HostName(options.Server), options.GetPort().ToString(), SocketProtectionLevel.Tls12);
} }
}


public int Peek()
{
return 0;
ReceiveStream = _socket.InputStream.AsStreamForRead();
SendStream = _socket.OutputStream.AsStreamForWrite();
RawStream = ReceiveStream;
} }


public async Task<ArraySegment<byte>> ReadAsync(int length, byte[] buffer)
public Task DisconnectAsync()
{ {
if (buffer == null) throw new ArgumentNullException(nameof(buffer));

try
{
var result = await _socket.InputStream.ReadAsync(buffer.AsBuffer(), (uint)buffer.Length, InputStreamOptions.None);
return new ArraySegment<byte>(buffer, 0, (int)result.Length);
}
catch (SocketException exception)
{
throw new MqttCommunicationException(exception);
}
Dispose();
return Task.FromResult(0);
} }


public void Dispose() public void Dispose()
{ {
RawStream?.Dispose();
RawStream = null;

SendStream?.Dispose();
SendStream = null;

ReceiveStream?.Dispose();
ReceiveStream = null;

_socket?.Dispose(); _socket?.Dispose();
_socket = null; _socket = null;
} }


+ 0
- 1
Frameworks/MQTTnet.UniversalWindows/Implementations/MqttWebSocketChannel.cs View File

@@ -14,7 +14,6 @@ namespace MQTTnet.Implementations
private ClientWebSocket _webSocket = new ClientWebSocket(); private ClientWebSocket _webSocket = new ClientWebSocket();


public Stream RawStream { get; private set; } public Stream RawStream { get; private set; }

public Stream SendStream => RawStream; public Stream SendStream => RawStream;
public Stream ReceiveStream => RawStream; public Stream ReceiveStream => RawStream;




+ 5
- 0
Frameworks/MQTTnet.UniversalWindows/MQTTnet.UniversalWindows.csproj View File

@@ -130,6 +130,11 @@
<Version>5.3.3</Version> <Version>5.3.3</Version>
</PackageReference> </PackageReference>
</ItemGroup> </ItemGroup>
<ItemGroup>
<Reference Include="System.Net.Security">
<HintPath>..\..\..\..\Program Files\dotnet\sdk\NuGetFallbackFolder\microsoft.netcore.app\2.0.0\ref\netcoreapp2.0\System.Net.Security.dll</HintPath>
</Reference>
</ItemGroup>
<PropertyGroup Condition=" '$(VisualStudioVersion)' == '' or '$(VisualStudioVersion)' &lt; '14.0' "> <PropertyGroup Condition=" '$(VisualStudioVersion)' == '' or '$(VisualStudioVersion)' &lt; '14.0' ">
<VisualStudioVersion>14.0</VisualStudioVersion> <VisualStudioVersion>14.0</VisualStudioVersion>
</PropertyGroup> </PropertyGroup>


+ 67
- 39
MQTTnet.Core/Adapter/MqttChannelCommunicationAdapter.cs View File

@@ -15,7 +15,6 @@ namespace MQTTnet.Core.Adapter
public class MqttChannelCommunicationAdapter : IMqttCommunicationAdapter public class MqttChannelCommunicationAdapter : IMqttCommunicationAdapter
{ {
private readonly IMqttCommunicationChannel _channel; private readonly IMqttCommunicationChannel _channel;
private readonly byte[] _readBuffer = new byte[BufferConstants.Size];


private Task _sendTask = Task.FromResult(0); // this task is used to prevent overlapping write private Task _sendTask = Task.FromResult(0); // this task is used to prevent overlapping write


@@ -29,76 +28,105 @@ namespace MQTTnet.Core.Adapter


public Task ConnectAsync(MqttClientOptions options, TimeSpan timeout) public Task ConnectAsync(MqttClientOptions options, TimeSpan timeout)
{ {
return _channel.ConnectAsync(options).TimeoutAfter(timeout);
try
{
return _channel.ConnectAsync(options).TimeoutAfter(timeout);
}
catch (Exception exception)
{
throw new MqttCommunicationException(exception);
}
} }


public Task DisconnectAsync() public Task DisconnectAsync()
{ {
return _channel.DisconnectAsync();
try
{
return _channel.DisconnectAsync();
}
catch (Exception exception)
{
throw new MqttCommunicationException(exception);
}
} }


public async Task SendPacketsAsync(TimeSpan timeout, IEnumerable<MqttBasePacket> packets) public async Task SendPacketsAsync(TimeSpan timeout, IEnumerable<MqttBasePacket> packets)
{ {
lock (_channel)
try
{ {
foreach (var packet in packets)
lock (_channel)
{ {
MqttTrace.Information(nameof(MqttChannelCommunicationAdapter), "TX >>> {0} [Timeout={1}]", packet, timeout);
foreach (var packet in packets)
{
MqttTrace.Information(nameof(MqttChannelCommunicationAdapter), "TX >>> {0} [Timeout={1}]", packet, timeout);


var writeBuffer = PacketSerializer.Serialize(packet);
_sendTask = _sendTask.ContinueWith(p => _channel.SendStream.WriteAsync(writeBuffer, 0, writeBuffer.Length));
var writeBuffer = PacketSerializer.Serialize(packet);
_sendTask = _sendTask.ContinueWith(p => _channel.SendStream.WriteAsync(writeBuffer, 0, writeBuffer.Length));
}
} }
}


await _sendTask; // configure await false geneates stackoverflow
await _channel.SendStream.FlushAsync().TimeoutAfter(timeout).ConfigureAwait(false);
await _sendTask; // configure await false geneates stackoverflow
await _channel.SendStream.FlushAsync().TimeoutAfter(timeout).ConfigureAwait(false);
}
catch (Exception exception)
{
throw new MqttCommunicationException(exception);
}
} }


public async Task<MqttBasePacket> ReceivePacketAsync(TimeSpan timeout) public async Task<MqttBasePacket> ReceivePacketAsync(TimeSpan timeout)
{ {
Tuple<MqttPacketHeader, MemoryStream> tuple;
if (timeout > TimeSpan.Zero)
{
tuple = await ReceiveAsync(_channel.RawStream).TimeoutAfter(timeout).ConfigureAwait(false);
}
else
try
{ {
tuple = await ReceiveAsync(_channel.ReceiveStream).ConfigureAwait(false);
}
ReceivedMqttPacket receivedMqttPacket;
if (timeout > TimeSpan.Zero)
{
receivedMqttPacket = await ReceiveAsync(_channel.RawStream).TimeoutAfter(timeout).ConfigureAwait(false);
}
else
{
receivedMqttPacket = await ReceiveAsync(_channel.ReceiveStream).ConfigureAwait(false);
}


var packet = PacketSerializer.Deserialize(tuple.Item1, tuple.Item2);
var packet = PacketSerializer.Deserialize(receivedMqttPacket);
if (packet == null)
{
throw new MqttProtocolViolationException("Received malformed packet.");
}


if (packet == null)
MqttTrace.Information(nameof(MqttChannelCommunicationAdapter), "RX <<< {0}", packet);
return packet;
}
catch (Exception exception)
{ {
throw new MqttProtocolViolationException("Received malformed packet.");
throw new MqttCommunicationException(exception);
} }

MqttTrace.Information(nameof(MqttChannelCommunicationAdapter), "RX <<< {0}", packet);
return packet;
} }


private async Task<Tuple<MqttPacketHeader, MemoryStream>> ReceiveAsync(Stream stream)
private async Task<ReceivedMqttPacket> ReceiveAsync(Stream stream)
{ {
var header = MqttPacketReader.ReadHeaderFromSource(stream); var header = MqttPacketReader.ReadHeaderFromSource(stream);


MemoryStream body;
if (header.BodyLength > 0)
if (header.BodyLength == 0)
{ {
var totalRead = 0;
do
{
var read = await stream.ReadAsync(_readBuffer, totalRead, header.BodyLength - totalRead).ConfigureAwait(false);
totalRead += read;
} while (totalRead < header.BodyLength);
body = new MemoryStream(_readBuffer, 0, header.BodyLength);
return new ReceivedMqttPacket(header, new MemoryStream(0));
} }
else

var body = new byte[header.BodyLength];

var offset = 0;
do
{
var readBytesCount = await stream.ReadAsync(body, offset, body.Length - offset).ConfigureAwait(false);
offset += readBytesCount;
} while (offset < header.BodyLength);

if (offset > header.BodyLength)
{ {
body = new MemoryStream();
throw new MqttCommunicationException($"Read more body bytes than required ({offset}/{header.BodyLength}).");
} }


return Tuple.Create(header, body);
return new ReceivedMqttPacket(header, new MemoryStream(body, 0, body.Length));
} }
} }
} }

+ 19
- 0
MQTTnet.Core/Adapter/ReceivedMqttPacket.cs View File

@@ -0,0 +1,19 @@
using System;
using System.IO;
using MQTTnet.Core.Packets;

namespace MQTTnet.Core.Adapter
{
public class ReceivedMqttPacket
{
public ReceivedMqttPacket(MqttPacketHeader header, MemoryStream body)
{
Header = header ?? throw new ArgumentNullException(nameof(header));
Body = body ?? throw new ArgumentNullException(nameof(body));
}

public MqttPacketHeader Header { get; }

public MemoryStream Body { get; }
}
}

+ 5
- 0
MQTTnet.Core/Client/MqttClient.cs View File

@@ -99,6 +99,11 @@ namespace MQTTnet.Core.Client


public async Task DisconnectAsync() public async Task DisconnectAsync()
{ {
if (!IsConnected)
{
return;
}

try try
{ {
await SendAsync(new MqttDisconnectPacket()).ConfigureAwait(false); await SendAsync(new MqttDisconnectPacket()).ConfigureAwait(false);


+ 8
- 4
MQTTnet.Core/Client/MqttPacketDispatcher.cs View File

@@ -11,7 +11,7 @@ namespace MQTTnet.Core.Client
public class MqttPacketDispatcher public class MqttPacketDispatcher
{ {
private readonly ConcurrentDictionary<Type, TaskCompletionSource<MqttBasePacket>> _packetByResponseType = new ConcurrentDictionary<Type, TaskCompletionSource<MqttBasePacket>>(); private readonly ConcurrentDictionary<Type, TaskCompletionSource<MqttBasePacket>> _packetByResponseType = new ConcurrentDictionary<Type, TaskCompletionSource<MqttBasePacket>>();
private readonly ConcurrentDictionary<Type, ConcurrentDictionary<ushort,TaskCompletionSource<MqttBasePacket>>> _packetByResponseTypeAndIdentifier = new ConcurrentDictionary<Type, ConcurrentDictionary<ushort,TaskCompletionSource<MqttBasePacket>>>();
private readonly ConcurrentDictionary<Type, ConcurrentDictionary<ushort, TaskCompletionSource<MqttBasePacket>>> _packetByResponseTypeAndIdentifier = new ConcurrentDictionary<Type, ConcurrentDictionary<ushort, TaskCompletionSource<MqttBasePacket>>>();


public async Task<MqttBasePacket> WaitForPacketAsync(MqttBasePacket request, Type responseType, TimeSpan timeout) public async Task<MqttBasePacket> WaitForPacketAsync(MqttBasePacket request, Type responseType, TimeSpan timeout)
{ {
@@ -24,7 +24,7 @@ namespace MQTTnet.Core.Client
} }
catch (MqttCommunicationTimedOutException) catch (MqttCommunicationTimedOutException)
{ {
MqttTrace.Warning(nameof(MqttPacketDispatcher), "Timeout while waiting for packet.");
MqttTrace.Warning(nameof(MqttPacketDispatcher), "Timeout while waiting for packet of type '{0}'.", responseType.Name);
throw; throw;
} }
finally finally
@@ -42,16 +42,20 @@ namespace MQTTnet.Core.Client
{ {
if (_packetByResponseTypeAndIdentifier.TryGetValue(type, out var byid)) if (_packetByResponseTypeAndIdentifier.TryGetValue(type, out var byid))
{ {
if (byid.TryRemove( withIdentifier.PacketIdentifier, out var tcs))
if (byid.TryRemove(withIdentifier.PacketIdentifier, out var tcs))
{ {
tcs.TrySetResult( packet );
tcs.TrySetResult(packet);
return;
} }
} }
} }
else if (_packetByResponseType.TryRemove(type, out var tcs)) else if (_packetByResponseType.TryRemove(type, out var tcs))
{ {
tcs.TrySetResult(packet); tcs.TrySetResult(packet);
return;
} }

throw new InvalidOperationException($"Packet of type '{type.Name}' not handled or dispatched.");
} }


public void Reset() public void Reset()


+ 13
- 11
MQTTnet.Core/Internal/TaskExtensions.cs View File

@@ -25,20 +25,22 @@ namespace MQTTnet.Core.Internal


try try
{ {
cancellationTokenSource.CancelAfter(timeout);
#pragma warning disable 4014
task.ContinueWith(t => task.ContinueWith(t =>
{
if (t.IsFaulted)
{
tcs.TrySetException(t.Exception);
}
#pragma warning restore 4014
{
if (t.IsFaulted)
{
tcs.TrySetException(t.Exception);
}


if (t.IsCompleted)
{
tcs.TrySetResult(t.Result);
}
}, cancellationTokenSource.Token);
if (t.IsCompleted)
{
tcs.TrySetResult(t.Result);
}
}, cancellationTokenSource.Token);


cancellationTokenSource.CancelAfter(timeout);
return await tcs.Task; return await tcs.Task;
} }
catch (TaskCanceledException) catch (TaskCanceledException)


+ 2
- 2
MQTTnet.Core/Serializer/IMqttPacketSerializer.cs View File

@@ -1,4 +1,4 @@
using System.IO;
using MQTTnet.Core.Adapter;
using MQTTnet.Core.Packets; using MQTTnet.Core.Packets;


namespace MQTTnet.Core.Serializer namespace MQTTnet.Core.Serializer
@@ -9,6 +9,6 @@ namespace MQTTnet.Core.Serializer


byte[] Serialize(MqttBasePacket mqttPacket); byte[] Serialize(MqttBasePacket mqttPacket);


MqttBasePacket Deserialize(MqttPacketHeader header, MemoryStream stream);
MqttBasePacket Deserialize(ReceivedMqttPacket receivedMqttPacket);
} }
} }

+ 23
- 22
MQTTnet.Core/Serializer/MqttPacketReader.cs View File

@@ -1,6 +1,7 @@
using System; using System;
using System.IO; using System.IO;
using System.Text; using System.Text;
using MQTTnet.Core.Adapter;
using MQTTnet.Core.Exceptions; using MQTTnet.Core.Exceptions;
using MQTTnet.Core.Protocol; using MQTTnet.Core.Protocol;
using MQTTnet.Core.Packets; using MQTTnet.Core.Packets;
@@ -9,15 +10,29 @@ namespace MQTTnet.Core.Serializer
{ {
public sealed class MqttPacketReader : BinaryReader public sealed class MqttPacketReader : BinaryReader
{ {
private readonly MqttPacketHeader _header;
public MqttPacketReader(Stream stream, MqttPacketHeader header)
: base(stream, Encoding.UTF8, true)
private readonly ReceivedMqttPacket _receivedMqttPacket;
public MqttPacketReader(ReceivedMqttPacket receivedMqttPacket)
: base(receivedMqttPacket.Body, Encoding.UTF8, true)
{ {
_header = header;
_receivedMqttPacket = receivedMqttPacket;
} }


public bool EndOfRemainingData => BaseStream.Position == _header.BodyLength;
public bool EndOfRemainingData => BaseStream.Position == _receivedMqttPacket.Header.BodyLength;

public static MqttPacketHeader ReadHeaderFromSource(Stream stream)
{
var fixedHeader = (byte)stream.ReadByte();
var controlPacketType = (MqttControlPacketType)(fixedHeader >> 4);
var bodyLength = ReadBodyLengthFromSource(stream);

return new MqttPacketHeader
{
FixedHeader = fixedHeader,
ControlPacketType = controlPacketType,
BodyLength = bodyLength
};
}


public override ushort ReadUInt16() public override ushort ReadUInt16()
{ {
@@ -44,21 +59,7 @@ namespace MQTTnet.Core.Serializer


public byte[] ReadRemainingData() public byte[] ReadRemainingData()
{ {
return ReadBytes(_header.BodyLength - (int)BaseStream.Position);
}

public static MqttPacketHeader ReadHeaderFromSource(Stream stream)
{
var fixedHeader = (byte)stream.ReadByte();
var controlPacketType = (MqttControlPacketType)(fixedHeader >> 4);
var bodyLength = ReadBodyLengthFromSource(stream);

return new MqttPacketHeader
{
FixedHeader = fixedHeader,
ControlPacketType = controlPacketType,
BodyLength = bodyLength
};
return ReadBytes(_receivedMqttPacket.Header.BodyLength - (int)BaseStream.Position);
} }


private static int ReadBodyLengthFromSource(Stream stream) private static int ReadBodyLengthFromSource(Stream stream)
@@ -74,7 +75,7 @@ namespace MQTTnet.Core.Serializer
multiplier *= 128; multiplier *= 128;
if (multiplier > 128 * 128 * 128) if (multiplier > 128 * 128 * 128)
{ {
throw new MqttProtocolViolationException("Remaining length is ivalid.");
throw new MqttProtocolViolationException("Remaining length is invalid.");
} }
} while ((encodedByte & 128) != 0); } while ((encodedByte & 128) != 0);
return value; return value;


+ 5
- 5
MQTTnet.Core/Serializer/MqttPacketSerializer.cs View File

@@ -3,6 +3,7 @@ using System.Collections.Generic;
using System.IO; using System.IO;
using System.Linq; using System.Linq;
using System.Text; using System.Text;
using MQTTnet.Core.Adapter;
using MQTTnet.Core.Exceptions; using MQTTnet.Core.Exceptions;
using MQTTnet.Core.Packets; using MQTTnet.Core.Packets;
using MQTTnet.Core.Protocol; using MQTTnet.Core.Protocol;
@@ -110,14 +111,13 @@ namespace MQTTnet.Core.Serializer
throw new MqttProtocolViolationException("Packet type invalid."); throw new MqttProtocolViolationException("Packet type invalid.");
} }


public MqttBasePacket Deserialize(MqttPacketHeader header, MemoryStream body)
public MqttBasePacket Deserialize(ReceivedMqttPacket receivedMqttPacket)
{ {
if (header == null) throw new ArgumentNullException(nameof(header));
if (body == null) throw new ArgumentNullException(nameof(body));
if (receivedMqttPacket == null) throw new ArgumentNullException(nameof(receivedMqttPacket));


using (var reader = new MqttPacketReader(body, header))
using (var reader = new MqttPacketReader(receivedMqttPacket))
{ {
return Deserialize(header, reader);
return Deserialize(receivedMqttPacket.Header, reader);
} }
} }




+ 9
- 8
Tests/MQTTnet.Core.Tests/MqttPacketSerializerTests.cs View File

@@ -3,6 +3,7 @@ using System.IO;
using System.Text; using System.Text;
using System.Threading.Tasks; using System.Threading.Tasks;
using Microsoft.VisualStudio.TestTools.UnitTesting; using Microsoft.VisualStudio.TestTools.UnitTesting;
using MQTTnet.Core.Adapter;
using MQTTnet.Core.Channel; using MQTTnet.Core.Channel;
using MQTTnet.Core.Client; using MQTTnet.Core.Client;
using MQTTnet.Core.Packets; using MQTTnet.Core.Packets;
@@ -436,20 +437,20 @@ namespace MQTTnet.Core.Tests
private static void DeserializeAndCompare(MqttBasePacket packet, string expectedBase64Value) private static void DeserializeAndCompare(MqttBasePacket packet, string expectedBase64Value)
{ {
var serializer = new MqttPacketSerializer(); var serializer = new MqttPacketSerializer();
var buffer1 = serializer.Serialize(packet); var buffer1 = serializer.Serialize(packet);


using (var headerStream = new MemoryStream( buffer1 ))
using (var headerStream = new MemoryStream(buffer1))
{ {
var header = MqttPacketReader.ReadHeaderFromSource( headerStream );
var header = MqttPacketReader.ReadHeaderFromSource(headerStream);


using (var bodyStream = new MemoryStream( buffer1, (int)headerStream.Position, header.BodyLength ))
using (var bodyStream = new MemoryStream(buffer1, (int)headerStream.Position, header.BodyLength))
{ {
var deserializedPacket = serializer.Deserialize(header, bodyStream);
var buffer2 = serializer.Serialize( deserializedPacket );
var deserializedPacket = serializer.Deserialize(new ReceivedMqttPacket(header, bodyStream));
var buffer2 = serializer.Serialize(deserializedPacket);


Assert.AreEqual( expectedBase64Value, Convert.ToBase64String( buffer2 ) );
}
Assert.AreEqual(expectedBase64Value, Convert.ToBase64String(buffer2));
}
} }
} }
} }


+ 26
- 27
Tests/MQTTnet.TestApp.NetFramework/PerformanceTest.cs View File

@@ -17,18 +17,18 @@ namespace MQTTnet.TestApp.NetFramework
{ {
public static async Task RunAsync() public static async Task RunAsync()
{ {
var server = Task.Run(() => RunServerAsync());
var client = Task.Run(() => RunClientAsync(2000, TimeSpan.FromMilliseconds(10)));
var server = Task.Factory.StartNew(RunServerAsync, TaskCreationOptions.LongRunning);
var client = Task.Factory.StartNew(() => RunClientAsync(2000, TimeSpan.FromMilliseconds(10)), TaskCreationOptions.LongRunning);


await Task.WhenAll(server, client).ConfigureAwait(false); await Task.WhenAll(server, client).ConfigureAwait(false);
} }


private static Task RunClientsAsync(int msgChunkSize, TimeSpan interval) private static Task RunClientsAsync(int msgChunkSize, TimeSpan interval)
{ {
return Task.WhenAll(Enumerable.Range(0, 3).Select((i) => Task.Run(() => RunClientAsync(msgChunkSize, interval))));
return Task.WhenAll(Enumerable.Range(0, 3).Select(i => Task.Run(() => RunClientAsync(msgChunkSize, interval))));
} }


private static async Task RunClientAsync( int msgChunkSize, TimeSpan interval )
private static async Task RunClientAsync(int msgChunkSize, TimeSpan interval)
{ {
try try
{ {
@@ -83,7 +83,7 @@ namespace MQTTnet.TestApp.NetFramework


Console.WriteLine("### WAITING FOR APPLICATION MESSAGES ###"); Console.WriteLine("### WAITING FOR APPLICATION MESSAGES ###");


var testMessageCount = 1000;
var testMessageCount = 10000;
var message = CreateMessage(); var message = CreateMessage();
var stopwatch = Stopwatch.StartNew(); var stopwatch = Stopwatch.StartNew();
for (var i = 0; i < testMessageCount; i++) for (var i = 0; i < testMessageCount; i++)
@@ -92,8 +92,8 @@ namespace MQTTnet.TestApp.NetFramework
} }


stopwatch.Stop(); stopwatch.Stop();
Console.WriteLine($"Sent 1000 messages within {stopwatch.ElapsedMilliseconds} ms ({stopwatch.ElapsedMilliseconds / (float)testMessageCount} ms / message).");
Console.WriteLine($"Sent 10.000 messages within {stopwatch.ElapsedMilliseconds} ms ({stopwatch.ElapsedMilliseconds / (float)testMessageCount} ms / message).");
stopwatch.Restart(); stopwatch.Restart();
var sentMessagesCount = 0; var sentMessagesCount = 0;
while (stopwatch.ElapsedMilliseconds < 1000) while (stopwatch.ElapsedMilliseconds < 1000)
@@ -109,32 +109,32 @@ namespace MQTTnet.TestApp.NetFramework


while (true) while (true)
{ {
var msgs = Enumerable.Range( 0, msgChunkSize )
.Select( i => CreateMessage() )
var msgs = Enumerable.Range(0, msgChunkSize)
.Select(i => CreateMessage())
.ToList(); .ToList();


if (false) if (false)
{ {
//send concurrent (test for raceconditions) //send concurrent (test for raceconditions)
var sendTasks = msgs var sendTasks = msgs
.Select( msg => PublishSingleMessage( client, msg, ref msgCount ) )
.Select(msg => PublishSingleMessage(client, msg, ref msgCount))
.ToList(); .ToList();


await Task.WhenAll( sendTasks );
await Task.WhenAll(sendTasks);
} }
else else
{ {
await client.PublishAsync( msgs );
await client.PublishAsync(msgs);
msgCount += msgs.Count; msgCount += msgs.Count;
//send multiple //send multiple
} }




var now = DateTime.Now; var now = DateTime.Now;
if (last < now - TimeSpan.FromSeconds(1)) if (last < now - TimeSpan.FromSeconds(1))
{ {
Console.WriteLine( $"sending {msgCount} inteded {msgChunkSize / interval.TotalSeconds}" );
Console.WriteLine($"sending {msgCount} inteded {msgChunkSize / interval.TotalSeconds}");
msgCount = 0; msgCount = 0;
last = now; last = now;
} }
@@ -152,19 +152,19 @@ namespace MQTTnet.TestApp.NetFramework
{ {
return new MqttApplicationMessage( return new MqttApplicationMessage(
"A/B/C", "A/B/C",
Encoding.UTF8.GetBytes( "Hello World" ),
Encoding.UTF8.GetBytes("Hello World"),
MqttQualityOfServiceLevel.AtMostOnce, MqttQualityOfServiceLevel.AtMostOnce,
false false
); );
} }


private static Task PublishSingleMessage( IMqttClient client, MqttApplicationMessage applicationMessage, ref int count )
private static Task PublishSingleMessage(IMqttClient client, MqttApplicationMessage applicationMessage, ref int count)
{ {
Interlocked.Increment( ref count );
return Task.Run( () =>
{
return client.PublishAsync( applicationMessage );
} );
Interlocked.Increment(ref count);
return Task.Run(() =>
{
return client.PublishAsync(applicationMessage);
});
} }


private static void RunServerAsync() private static void RunServerAsync()
@@ -187,19 +187,18 @@ namespace MQTTnet.TestApp.NetFramework
}, },
DefaultCommunicationTimeout = TimeSpan.FromMinutes(10) DefaultCommunicationTimeout = TimeSpan.FromMinutes(10)
}; };
var mqttServer = new MqttServerFactory().CreateMqttServer(options); var mqttServer = new MqttServerFactory().CreateMqttServer(options);
var last = DateTime.Now;
var msgs = 0; var msgs = 0;
mqttServer.ApplicationMessageReceived += (sender, args) =>
var stopwatch = Stopwatch.StartNew();
mqttServer.ApplicationMessageReceived += (sender, args) =>
{ {
msgs++; msgs++;
var now = DateTime.Now;
if (last < now - TimeSpan.FromSeconds(1))
if (stopwatch.ElapsedMilliseconds > 1000)
{ {
Console.WriteLine($"received {msgs}"); Console.WriteLine($"received {msgs}");
msgs = 0; msgs = 0;
last = now;
stopwatch.Restart();
} }
}; };
mqttServer.Start(); mqttServer.Start();


+ 1
- 1
Tests/MQTTnet.TestApp.UniversalWindows/MQTTnet.TestApp.UniversalWindows.csproj View File

@@ -42,7 +42,7 @@
<UseVSHostingProcess>false</UseVSHostingProcess> <UseVSHostingProcess>false</UseVSHostingProcess>
<ErrorReport>prompt</ErrorReport> <ErrorReport>prompt</ErrorReport>
<Prefer32Bit>true</Prefer32Bit> <Prefer32Bit>true</Prefer32Bit>
<UseDotNetNativeToolchain>true</UseDotNetNativeToolchain>
<UseDotNetNativeToolchain>false</UseDotNetNativeToolchain>
</PropertyGroup> </PropertyGroup>
<PropertyGroup Condition="'$(Configuration)|$(Platform)' == 'Debug|ARM'"> <PropertyGroup Condition="'$(Configuration)|$(Platform)' == 'Debug|ARM'">
<DebugSymbols>true</DebugSymbols> <DebugSymbols>true</DebugSymbols>


Loading…
Cancel
Save