@@ -6,7 +6,6 @@ using System.Security.Cryptography.X509Certificates; | |||||
using System.Threading.Tasks; | using System.Threading.Tasks; | ||||
using MQTTnet.Core.Channel; | using MQTTnet.Core.Channel; | ||||
using MQTTnet.Core.Client; | using MQTTnet.Core.Client; | ||||
using MQTTnet.Core.Exceptions; | |||||
using System.IO; | using System.IO; | ||||
namespace MQTTnet.Implementations | namespace MQTTnet.Implementations | ||||
@@ -16,10 +15,6 @@ namespace MQTTnet.Implementations | |||||
private Socket _socket; | private Socket _socket; | ||||
private SslStream _sslStream; | private SslStream _sslStream; | ||||
public Stream RawStream { get; private set; } | |||||
public Stream SendStream { get; private set; } | |||||
public Stream ReceiveStream { get; private set; } | |||||
/// <summary> | /// <summary> | ||||
/// called on client sockets are created in connect | /// called on client sockets are created in connect | ||||
/// </summary> | /// </summary> | ||||
@@ -36,61 +31,61 @@ namespace MQTTnet.Implementations | |||||
{ | { | ||||
_socket = socket ?? throw new ArgumentNullException(nameof(socket)); | _socket = socket ?? throw new ArgumentNullException(nameof(socket)); | ||||
_sslStream = sslStream; | _sslStream = sslStream; | ||||
CreateCommStreams(socket, sslStream); | |||||
CreateStreams(socket, sslStream); | |||||
} | } | ||||
public Stream RawStream { get; private set; } | |||||
public Stream SendStream { get; private set; } | |||||
public Stream ReceiveStream { get; private set; } | |||||
public async Task ConnectAsync(MqttClientOptions options) | public async Task ConnectAsync(MqttClientOptions options) | ||||
{ | { | ||||
if (options == null) throw new ArgumentNullException(nameof(options)); | if (options == null) throw new ArgumentNullException(nameof(options)); | ||||
try | |||||
{ | |||||
if (_socket == null) | |||||
{ | |||||
_socket = new Socket(SocketType.Stream, ProtocolType.Tcp); | |||||
} | |||||
await Task.Factory.FromAsync(_socket.BeginConnect, _socket.EndConnect, options.Server, options.GetPort(), null).ConfigureAwait(false); | |||||
if (options.TlsOptions.UseTls) | |||||
{ | |||||
_sslStream = new SslStream(new NetworkStream(_socket, true)); | |||||
if (_socket == null) | |||||
{ | |||||
_socket = new Socket(SocketType.Stream, ProtocolType.Tcp); | |||||
} | |||||
await _sslStream.AuthenticateAsClientAsync(options.Server, LoadCertificates(options), SslProtocols.Tls12, options.TlsOptions.CheckCertificateRevocation).ConfigureAwait(false); | |||||
} | |||||
await Task.Factory.FromAsync(_socket.BeginConnect, _socket.EndConnect, options.Server, options.GetPort(), null).ConfigureAwait(false); | |||||
CreateCommStreams(_socket, _sslStream); | |||||
} | |||||
catch (SocketException exception) | |||||
if (options.TlsOptions.UseTls) | |||||
{ | { | ||||
throw new MqttCommunicationException(exception); | |||||
_sslStream = new SslStream(new NetworkStream(_socket, true)); | |||||
await _sslStream.AuthenticateAsClientAsync(options.Server, LoadCertificates(options), SslProtocols.Tls12, options.TlsOptions.CheckCertificateRevocation).ConfigureAwait(false); | |||||
} | } | ||||
CreateStreams(_socket, _sslStream); | |||||
} | } | ||||
public Task DisconnectAsync() | public Task DisconnectAsync() | ||||
{ | { | ||||
try | |||||
{ | |||||
Dispose(); | |||||
return Task.FromResult(0); | |||||
} | |||||
catch (SocketException exception) | |||||
{ | |||||
throw new MqttCommunicationException(exception); | |||||
} | |||||
Dispose(); | |||||
return Task.FromResult(0); | |||||
} | } | ||||
public void Dispose() | public void Dispose() | ||||
{ | { | ||||
_socket?.Dispose(); | |||||
_sslStream?.Dispose(); | |||||
RawStream?.Dispose(); | |||||
RawStream = null; | |||||
ReceiveStream?.Dispose(); | |||||
ReceiveStream = null; | |||||
SendStream?.Dispose(); | |||||
SendStream = null; | |||||
_socket?.Dispose(); | |||||
_socket = null; | _socket = null; | ||||
_sslStream?.Dispose(); | |||||
_sslStream = null; | _sslStream = null; | ||||
} | } | ||||
private void CreateCommStreams(Socket socket, SslStream sslStream) | |||||
private void CreateStreams(Socket socket, Stream sslStream) | |||||
{ | { | ||||
RawStream = (Stream)sslStream ?? new NetworkStream(socket); | |||||
RawStream = sslStream ?? new NetworkStream(socket); | |||||
//cannot use this as default buffering prevents from receiving the first connect message | //cannot use this as default buffering prevents from receiving the first connect message | ||||
//need two streams otherwise read and write have to be synchronized | //need two streams otherwise read and write have to be synchronized | ||||
@@ -14,7 +14,6 @@ namespace MQTTnet.Implementations | |||||
private ClientWebSocket _webSocket = new ClientWebSocket(); | private ClientWebSocket _webSocket = new ClientWebSocket(); | ||||
public Stream RawStream { get; private set; } | public Stream RawStream { get; private set; } | ||||
public Stream SendStream => RawStream; | public Stream SendStream => RawStream; | ||||
public Stream ReceiveStream => RawStream; | public Stream ReceiveStream => RawStream; | ||||
@@ -56,15 +56,12 @@ namespace MQTTnet.Implementations | |||||
public override bool CanSeek => false; | public override bool CanSeek => false; | ||||
public override bool CanWrite => true; | public override bool CanWrite => true; | ||||
public override long Length | |||||
{ | |||||
get { throw new NotSupportedException(); } | |||||
} | |||||
public override long Length => throw new NotSupportedException(); | |||||
public override long Position | public override long Position | ||||
{ | { | ||||
get { throw new NotSupportedException(); } | |||||
set { throw new NotSupportedException(); } | |||||
get => throw new NotSupportedException(); | |||||
set => throw new NotSupportedException(); | |||||
} | } | ||||
public override long Seek(long offset, SeekOrigin origin) | public override long Seek(long offset, SeekOrigin origin) | ||||
@@ -6,7 +6,6 @@ using System.Security.Cryptography.X509Certificates; | |||||
using System.Threading.Tasks; | using System.Threading.Tasks; | ||||
using MQTTnet.Core.Channel; | using MQTTnet.Core.Channel; | ||||
using MQTTnet.Core.Client; | using MQTTnet.Core.Client; | ||||
using MQTTnet.Core.Exceptions; | |||||
using System.IO; | using System.IO; | ||||
namespace MQTTnet.Implementations | namespace MQTTnet.Implementations | ||||
@@ -16,10 +15,6 @@ namespace MQTTnet.Implementations | |||||
private Socket _socket; | private Socket _socket; | ||||
private SslStream _sslStream; | private SslStream _sslStream; | ||||
public Stream ReceiveStream { get; private set; } | |||||
public Stream RawStream => ReceiveStream; | |||||
public Stream SendStream => ReceiveStream; | |||||
/// <summary> | /// <summary> | ||||
/// called on client sockets are created in connect | /// called on client sockets are created in connect | ||||
/// </summary> | /// </summary> | ||||
@@ -38,55 +33,45 @@ namespace MQTTnet.Implementations | |||||
ReceiveStream = (Stream)sslStream ?? new NetworkStream(socket); | ReceiveStream = (Stream)sslStream ?? new NetworkStream(socket); | ||||
} | } | ||||
public Stream ReceiveStream { get; private set; } | |||||
public Stream RawStream => ReceiveStream; | |||||
public Stream SendStream => ReceiveStream; | |||||
public async Task ConnectAsync(MqttClientOptions options) | public async Task ConnectAsync(MqttClientOptions options) | ||||
{ | { | ||||
if (options == null) throw new ArgumentNullException(nameof(options)); | if (options == null) throw new ArgumentNullException(nameof(options)); | ||||
try | |||||
if (_socket == null) | |||||
{ | { | ||||
if (_socket == null) | |||||
{ | |||||
_socket = new Socket(SocketType.Stream, ProtocolType.Tcp); | |||||
} | |||||
_socket = new Socket(SocketType.Stream, ProtocolType.Tcp); | |||||
} | |||||
await _socket.ConnectAsync(options.Server, options.GetPort()).ConfigureAwait(false); | |||||
await _socket.ConnectAsync(options.Server, options.GetPort()).ConfigureAwait(false); | |||||
if (options.TlsOptions.UseTls) | |||||
{ | |||||
_sslStream = new SslStream(new NetworkStream(_socket, true)); | |||||
ReceiveStream = _sslStream; | |||||
await _sslStream.AuthenticateAsClientAsync(options.Server, LoadCertificates(options), SslProtocols.Tls12, options.TlsOptions.CheckCertificateRevocation).ConfigureAwait(false); | |||||
} | |||||
else | |||||
{ | |||||
ReceiveStream = new NetworkStream(_socket); | |||||
} | |||||
if (options.TlsOptions.UseTls) | |||||
{ | |||||
_sslStream = new SslStream(new NetworkStream(_socket, true)); | |||||
ReceiveStream = _sslStream; | |||||
await _sslStream.AuthenticateAsClientAsync(options.Server, LoadCertificates(options), SslProtocols.Tls12, options.TlsOptions.CheckCertificateRevocation).ConfigureAwait(false); | |||||
} | } | ||||
catch (SocketException exception) | |||||
else | |||||
{ | { | ||||
throw new MqttCommunicationException(exception); | |||||
ReceiveStream = new NetworkStream(_socket); | |||||
} | } | ||||
} | } | ||||
public Task DisconnectAsync() | public Task DisconnectAsync() | ||||
{ | { | ||||
try | |||||
{ | |||||
Dispose(); | |||||
return Task.FromResult(0); | |||||
} | |||||
catch (SocketException exception) | |||||
{ | |||||
throw new MqttCommunicationException(exception); | |||||
} | |||||
Dispose(); | |||||
return Task.FromResult(0); | |||||
} | } | ||||
public void Dispose() | public void Dispose() | ||||
{ | { | ||||
_socket?.Dispose(); | _socket?.Dispose(); | ||||
_sslStream?.Dispose(); | |||||
_socket = null; | _socket = null; | ||||
_sslStream?.Dispose(); | |||||
_sslStream = null; | _sslStream = null; | ||||
} | } | ||||
@@ -1,15 +1,13 @@ | |||||
using System; | using System; | ||||
using System.IO; | |||||
using System.Linq; | using System.Linq; | ||||
using System.Net.Sockets; | |||||
using System.Runtime.InteropServices.WindowsRuntime; | using System.Runtime.InteropServices.WindowsRuntime; | ||||
using System.Threading.Tasks; | using System.Threading.Tasks; | ||||
using Windows.Networking; | using Windows.Networking; | ||||
using Windows.Networking.Sockets; | using Windows.Networking.Sockets; | ||||
using Windows.Security.Cryptography.Certificates; | using Windows.Security.Cryptography.Certificates; | ||||
using Windows.Storage.Streams; | |||||
using MQTTnet.Core.Channel; | using MQTTnet.Core.Channel; | ||||
using MQTTnet.Core.Client; | using MQTTnet.Core.Client; | ||||
using MQTTnet.Core.Exceptions; | |||||
namespace MQTTnet.Implementations | namespace MQTTnet.Implementations | ||||
{ | { | ||||
@@ -26,89 +24,58 @@ namespace MQTTnet.Implementations | |||||
_socket = socket ?? throw new ArgumentNullException(nameof(socket)); | _socket = socket ?? throw new ArgumentNullException(nameof(socket)); | ||||
} | } | ||||
public Stream SendStream { get; private set; } | |||||
public Stream ReceiveStream { get; private set; } | |||||
public Stream RawStream { get; private set; } | |||||
public async Task ConnectAsync(MqttClientOptions options) | public async Task ConnectAsync(MqttClientOptions options) | ||||
{ | { | ||||
if (options == null) throw new ArgumentNullException(nameof(options)); | if (options == null) throw new ArgumentNullException(nameof(options)); | ||||
try | |||||
{ | |||||
if (_socket == null) | |||||
{ | |||||
_socket = new StreamSocket(); | |||||
} | |||||
if (!options.TlsOptions.UseTls) | |||||
{ | |||||
await _socket.ConnectAsync(new HostName(options.Server), options.GetPort().ToString()); | |||||
} | |||||
else | |||||
{ | |||||
_socket.Control.ClientCertificate = LoadCertificate(options); | |||||
if (!options.TlsOptions.CheckCertificateRevocation) | |||||
{ | |||||
_socket.Control.IgnorableServerCertificateErrors.Add(ChainValidationResult.IncompleteChain); | |||||
_socket.Control.IgnorableServerCertificateErrors.Add(ChainValidationResult.RevocationInformationMissing); | |||||
} | |||||
await _socket.ConnectAsync(new HostName(options.Server), options.GetPort().ToString(), SocketProtectionLevel.Tls12); | |||||
} | |||||
} | |||||
catch (SocketException exception) | |||||
if (_socket == null) | |||||
{ | { | ||||
throw new MqttCommunicationException(exception); | |||||
_socket = new StreamSocket(); | |||||
} | } | ||||
} | |||||
public Task DisconnectAsync() | |||||
{ | |||||
try | |||||
if (!options.TlsOptions.UseTls) | |||||
{ | { | ||||
Dispose(); | |||||
return Task.FromResult(0); | |||||
await _socket.ConnectAsync(new HostName(options.Server), options.GetPort().ToString()); | |||||
} | } | ||||
catch (SocketException exception) | |||||
else | |||||
{ | { | ||||
throw new MqttCommunicationException(exception); | |||||
} | |||||
} | |||||
_socket.Control.ClientCertificate = LoadCertificate(options); | |||||
public async Task WriteAsync(byte[] buffer) | |||||
{ | |||||
if (buffer == null) throw new ArgumentNullException(nameof(buffer)); | |||||
if (!options.TlsOptions.CheckCertificateRevocation) | |||||
{ | |||||
_socket.Control.IgnorableServerCertificateErrors.Add(ChainValidationResult.IncompleteChain); | |||||
_socket.Control.IgnorableServerCertificateErrors.Add(ChainValidationResult.RevocationInformationMissing); | |||||
} | |||||
try | |||||
{ | |||||
await _socket.OutputStream.WriteAsync(buffer.AsBuffer()); | |||||
await _socket.OutputStream.FlushAsync(); | |||||
} | |||||
catch (SocketException exception) | |||||
{ | |||||
throw new MqttCommunicationException(exception); | |||||
await _socket.ConnectAsync(new HostName(options.Server), options.GetPort().ToString(), SocketProtectionLevel.Tls12); | |||||
} | } | ||||
} | |||||
public int Peek() | |||||
{ | |||||
return 0; | |||||
ReceiveStream = _socket.InputStream.AsStreamForRead(); | |||||
SendStream = _socket.OutputStream.AsStreamForWrite(); | |||||
RawStream = ReceiveStream; | |||||
} | } | ||||
public async Task<ArraySegment<byte>> ReadAsync(int length, byte[] buffer) | |||||
public Task DisconnectAsync() | |||||
{ | { | ||||
if (buffer == null) throw new ArgumentNullException(nameof(buffer)); | |||||
try | |||||
{ | |||||
var result = await _socket.InputStream.ReadAsync(buffer.AsBuffer(), (uint)buffer.Length, InputStreamOptions.None); | |||||
return new ArraySegment<byte>(buffer, 0, (int)result.Length); | |||||
} | |||||
catch (SocketException exception) | |||||
{ | |||||
throw new MqttCommunicationException(exception); | |||||
} | |||||
Dispose(); | |||||
return Task.FromResult(0); | |||||
} | } | ||||
public void Dispose() | public void Dispose() | ||||
{ | { | ||||
RawStream?.Dispose(); | |||||
RawStream = null; | |||||
SendStream?.Dispose(); | |||||
SendStream = null; | |||||
ReceiveStream?.Dispose(); | |||||
ReceiveStream = null; | |||||
_socket?.Dispose(); | _socket?.Dispose(); | ||||
_socket = null; | _socket = null; | ||||
} | } | ||||
@@ -14,7 +14,6 @@ namespace MQTTnet.Implementations | |||||
private ClientWebSocket _webSocket = new ClientWebSocket(); | private ClientWebSocket _webSocket = new ClientWebSocket(); | ||||
public Stream RawStream { get; private set; } | public Stream RawStream { get; private set; } | ||||
public Stream SendStream => RawStream; | public Stream SendStream => RawStream; | ||||
public Stream ReceiveStream => RawStream; | public Stream ReceiveStream => RawStream; | ||||
@@ -130,6 +130,11 @@ | |||||
<Version>5.3.3</Version> | <Version>5.3.3</Version> | ||||
</PackageReference> | </PackageReference> | ||||
</ItemGroup> | </ItemGroup> | ||||
<ItemGroup> | |||||
<Reference Include="System.Net.Security"> | |||||
<HintPath>..\..\..\..\Program Files\dotnet\sdk\NuGetFallbackFolder\microsoft.netcore.app\2.0.0\ref\netcoreapp2.0\System.Net.Security.dll</HintPath> | |||||
</Reference> | |||||
</ItemGroup> | |||||
<PropertyGroup Condition=" '$(VisualStudioVersion)' == '' or '$(VisualStudioVersion)' < '14.0' "> | <PropertyGroup Condition=" '$(VisualStudioVersion)' == '' or '$(VisualStudioVersion)' < '14.0' "> | ||||
<VisualStudioVersion>14.0</VisualStudioVersion> | <VisualStudioVersion>14.0</VisualStudioVersion> | ||||
</PropertyGroup> | </PropertyGroup> | ||||
@@ -15,7 +15,6 @@ namespace MQTTnet.Core.Adapter | |||||
public class MqttChannelCommunicationAdapter : IMqttCommunicationAdapter | public class MqttChannelCommunicationAdapter : IMqttCommunicationAdapter | ||||
{ | { | ||||
private readonly IMqttCommunicationChannel _channel; | private readonly IMqttCommunicationChannel _channel; | ||||
private readonly byte[] _readBuffer = new byte[BufferConstants.Size]; | |||||
private Task _sendTask = Task.FromResult(0); // this task is used to prevent overlapping write | private Task _sendTask = Task.FromResult(0); // this task is used to prevent overlapping write | ||||
@@ -29,76 +28,105 @@ namespace MQTTnet.Core.Adapter | |||||
public Task ConnectAsync(MqttClientOptions options, TimeSpan timeout) | public Task ConnectAsync(MqttClientOptions options, TimeSpan timeout) | ||||
{ | { | ||||
return _channel.ConnectAsync(options).TimeoutAfter(timeout); | |||||
try | |||||
{ | |||||
return _channel.ConnectAsync(options).TimeoutAfter(timeout); | |||||
} | |||||
catch (Exception exception) | |||||
{ | |||||
throw new MqttCommunicationException(exception); | |||||
} | |||||
} | } | ||||
public Task DisconnectAsync() | public Task DisconnectAsync() | ||||
{ | { | ||||
return _channel.DisconnectAsync(); | |||||
try | |||||
{ | |||||
return _channel.DisconnectAsync(); | |||||
} | |||||
catch (Exception exception) | |||||
{ | |||||
throw new MqttCommunicationException(exception); | |||||
} | |||||
} | } | ||||
public async Task SendPacketsAsync(TimeSpan timeout, IEnumerable<MqttBasePacket> packets) | public async Task SendPacketsAsync(TimeSpan timeout, IEnumerable<MqttBasePacket> packets) | ||||
{ | { | ||||
lock (_channel) | |||||
try | |||||
{ | { | ||||
foreach (var packet in packets) | |||||
lock (_channel) | |||||
{ | { | ||||
MqttTrace.Information(nameof(MqttChannelCommunicationAdapter), "TX >>> {0} [Timeout={1}]", packet, timeout); | |||||
foreach (var packet in packets) | |||||
{ | |||||
MqttTrace.Information(nameof(MqttChannelCommunicationAdapter), "TX >>> {0} [Timeout={1}]", packet, timeout); | |||||
var writeBuffer = PacketSerializer.Serialize(packet); | |||||
_sendTask = _sendTask.ContinueWith(p => _channel.SendStream.WriteAsync(writeBuffer, 0, writeBuffer.Length)); | |||||
var writeBuffer = PacketSerializer.Serialize(packet); | |||||
_sendTask = _sendTask.ContinueWith(p => _channel.SendStream.WriteAsync(writeBuffer, 0, writeBuffer.Length)); | |||||
} | |||||
} | } | ||||
} | |||||
await _sendTask; // configure await false geneates stackoverflow | |||||
await _channel.SendStream.FlushAsync().TimeoutAfter(timeout).ConfigureAwait(false); | |||||
await _sendTask; // configure await false geneates stackoverflow | |||||
await _channel.SendStream.FlushAsync().TimeoutAfter(timeout).ConfigureAwait(false); | |||||
} | |||||
catch (Exception exception) | |||||
{ | |||||
throw new MqttCommunicationException(exception); | |||||
} | |||||
} | } | ||||
public async Task<MqttBasePacket> ReceivePacketAsync(TimeSpan timeout) | public async Task<MqttBasePacket> ReceivePacketAsync(TimeSpan timeout) | ||||
{ | { | ||||
Tuple<MqttPacketHeader, MemoryStream> tuple; | |||||
if (timeout > TimeSpan.Zero) | |||||
{ | |||||
tuple = await ReceiveAsync(_channel.RawStream).TimeoutAfter(timeout).ConfigureAwait(false); | |||||
} | |||||
else | |||||
try | |||||
{ | { | ||||
tuple = await ReceiveAsync(_channel.ReceiveStream).ConfigureAwait(false); | |||||
} | |||||
ReceivedMqttPacket receivedMqttPacket; | |||||
if (timeout > TimeSpan.Zero) | |||||
{ | |||||
receivedMqttPacket = await ReceiveAsync(_channel.RawStream).TimeoutAfter(timeout).ConfigureAwait(false); | |||||
} | |||||
else | |||||
{ | |||||
receivedMqttPacket = await ReceiveAsync(_channel.ReceiveStream).ConfigureAwait(false); | |||||
} | |||||
var packet = PacketSerializer.Deserialize(tuple.Item1, tuple.Item2); | |||||
var packet = PacketSerializer.Deserialize(receivedMqttPacket); | |||||
if (packet == null) | |||||
{ | |||||
throw new MqttProtocolViolationException("Received malformed packet."); | |||||
} | |||||
if (packet == null) | |||||
MqttTrace.Information(nameof(MqttChannelCommunicationAdapter), "RX <<< {0}", packet); | |||||
return packet; | |||||
} | |||||
catch (Exception exception) | |||||
{ | { | ||||
throw new MqttProtocolViolationException("Received malformed packet."); | |||||
throw new MqttCommunicationException(exception); | |||||
} | } | ||||
MqttTrace.Information(nameof(MqttChannelCommunicationAdapter), "RX <<< {0}", packet); | |||||
return packet; | |||||
} | } | ||||
private async Task<Tuple<MqttPacketHeader, MemoryStream>> ReceiveAsync(Stream stream) | |||||
private async Task<ReceivedMqttPacket> ReceiveAsync(Stream stream) | |||||
{ | { | ||||
var header = MqttPacketReader.ReadHeaderFromSource(stream); | var header = MqttPacketReader.ReadHeaderFromSource(stream); | ||||
MemoryStream body; | |||||
if (header.BodyLength > 0) | |||||
if (header.BodyLength == 0) | |||||
{ | { | ||||
var totalRead = 0; | |||||
do | |||||
{ | |||||
var read = await stream.ReadAsync(_readBuffer, totalRead, header.BodyLength - totalRead).ConfigureAwait(false); | |||||
totalRead += read; | |||||
} while (totalRead < header.BodyLength); | |||||
body = new MemoryStream(_readBuffer, 0, header.BodyLength); | |||||
return new ReceivedMqttPacket(header, new MemoryStream(0)); | |||||
} | } | ||||
else | |||||
var body = new byte[header.BodyLength]; | |||||
var offset = 0; | |||||
do | |||||
{ | |||||
var readBytesCount = await stream.ReadAsync(body, offset, body.Length - offset).ConfigureAwait(false); | |||||
offset += readBytesCount; | |||||
} while (offset < header.BodyLength); | |||||
if (offset > header.BodyLength) | |||||
{ | { | ||||
body = new MemoryStream(); | |||||
throw new MqttCommunicationException($"Read more body bytes than required ({offset}/{header.BodyLength})."); | |||||
} | } | ||||
return Tuple.Create(header, body); | |||||
return new ReceivedMqttPacket(header, new MemoryStream(body, 0, body.Length)); | |||||
} | } | ||||
} | } | ||||
} | } |
@@ -0,0 +1,19 @@ | |||||
using System; | |||||
using System.IO; | |||||
using MQTTnet.Core.Packets; | |||||
namespace MQTTnet.Core.Adapter | |||||
{ | |||||
public class ReceivedMqttPacket | |||||
{ | |||||
public ReceivedMqttPacket(MqttPacketHeader header, MemoryStream body) | |||||
{ | |||||
Header = header ?? throw new ArgumentNullException(nameof(header)); | |||||
Body = body ?? throw new ArgumentNullException(nameof(body)); | |||||
} | |||||
public MqttPacketHeader Header { get; } | |||||
public MemoryStream Body { get; } | |||||
} | |||||
} |
@@ -99,6 +99,11 @@ namespace MQTTnet.Core.Client | |||||
public async Task DisconnectAsync() | public async Task DisconnectAsync() | ||||
{ | { | ||||
if (!IsConnected) | |||||
{ | |||||
return; | |||||
} | |||||
try | try | ||||
{ | { | ||||
await SendAsync(new MqttDisconnectPacket()).ConfigureAwait(false); | await SendAsync(new MqttDisconnectPacket()).ConfigureAwait(false); | ||||
@@ -11,7 +11,7 @@ namespace MQTTnet.Core.Client | |||||
public class MqttPacketDispatcher | public class MqttPacketDispatcher | ||||
{ | { | ||||
private readonly ConcurrentDictionary<Type, TaskCompletionSource<MqttBasePacket>> _packetByResponseType = new ConcurrentDictionary<Type, TaskCompletionSource<MqttBasePacket>>(); | private readonly ConcurrentDictionary<Type, TaskCompletionSource<MqttBasePacket>> _packetByResponseType = new ConcurrentDictionary<Type, TaskCompletionSource<MqttBasePacket>>(); | ||||
private readonly ConcurrentDictionary<Type, ConcurrentDictionary<ushort,TaskCompletionSource<MqttBasePacket>>> _packetByResponseTypeAndIdentifier = new ConcurrentDictionary<Type, ConcurrentDictionary<ushort,TaskCompletionSource<MqttBasePacket>>>(); | |||||
private readonly ConcurrentDictionary<Type, ConcurrentDictionary<ushort, TaskCompletionSource<MqttBasePacket>>> _packetByResponseTypeAndIdentifier = new ConcurrentDictionary<Type, ConcurrentDictionary<ushort, TaskCompletionSource<MqttBasePacket>>>(); | |||||
public async Task<MqttBasePacket> WaitForPacketAsync(MqttBasePacket request, Type responseType, TimeSpan timeout) | public async Task<MqttBasePacket> WaitForPacketAsync(MqttBasePacket request, Type responseType, TimeSpan timeout) | ||||
{ | { | ||||
@@ -24,7 +24,7 @@ namespace MQTTnet.Core.Client | |||||
} | } | ||||
catch (MqttCommunicationTimedOutException) | catch (MqttCommunicationTimedOutException) | ||||
{ | { | ||||
MqttTrace.Warning(nameof(MqttPacketDispatcher), "Timeout while waiting for packet."); | |||||
MqttTrace.Warning(nameof(MqttPacketDispatcher), "Timeout while waiting for packet of type '{0}'.", responseType.Name); | |||||
throw; | throw; | ||||
} | } | ||||
finally | finally | ||||
@@ -42,16 +42,20 @@ namespace MQTTnet.Core.Client | |||||
{ | { | ||||
if (_packetByResponseTypeAndIdentifier.TryGetValue(type, out var byid)) | if (_packetByResponseTypeAndIdentifier.TryGetValue(type, out var byid)) | ||||
{ | { | ||||
if (byid.TryRemove( withIdentifier.PacketIdentifier, out var tcs)) | |||||
if (byid.TryRemove(withIdentifier.PacketIdentifier, out var tcs)) | |||||
{ | { | ||||
tcs.TrySetResult( packet ); | |||||
tcs.TrySetResult(packet); | |||||
return; | |||||
} | } | ||||
} | } | ||||
} | } | ||||
else if (_packetByResponseType.TryRemove(type, out var tcs)) | else if (_packetByResponseType.TryRemove(type, out var tcs)) | ||||
{ | { | ||||
tcs.TrySetResult(packet); | tcs.TrySetResult(packet); | ||||
return; | |||||
} | } | ||||
throw new InvalidOperationException($"Packet of type '{type.Name}' not handled or dispatched."); | |||||
} | } | ||||
public void Reset() | public void Reset() | ||||
@@ -25,20 +25,22 @@ namespace MQTTnet.Core.Internal | |||||
try | try | ||||
{ | { | ||||
cancellationTokenSource.CancelAfter(timeout); | |||||
#pragma warning disable 4014 | |||||
task.ContinueWith(t => | task.ContinueWith(t => | ||||
{ | |||||
if (t.IsFaulted) | |||||
{ | |||||
tcs.TrySetException(t.Exception); | |||||
} | |||||
#pragma warning restore 4014 | |||||
{ | |||||
if (t.IsFaulted) | |||||
{ | |||||
tcs.TrySetException(t.Exception); | |||||
} | |||||
if (t.IsCompleted) | |||||
{ | |||||
tcs.TrySetResult(t.Result); | |||||
} | |||||
}, cancellationTokenSource.Token); | |||||
if (t.IsCompleted) | |||||
{ | |||||
tcs.TrySetResult(t.Result); | |||||
} | |||||
}, cancellationTokenSource.Token); | |||||
cancellationTokenSource.CancelAfter(timeout); | |||||
return await tcs.Task; | return await tcs.Task; | ||||
} | } | ||||
catch (TaskCanceledException) | catch (TaskCanceledException) | ||||
@@ -1,4 +1,4 @@ | |||||
using System.IO; | |||||
using MQTTnet.Core.Adapter; | |||||
using MQTTnet.Core.Packets; | using MQTTnet.Core.Packets; | ||||
namespace MQTTnet.Core.Serializer | namespace MQTTnet.Core.Serializer | ||||
@@ -9,6 +9,6 @@ namespace MQTTnet.Core.Serializer | |||||
byte[] Serialize(MqttBasePacket mqttPacket); | byte[] Serialize(MqttBasePacket mqttPacket); | ||||
MqttBasePacket Deserialize(MqttPacketHeader header, MemoryStream stream); | |||||
MqttBasePacket Deserialize(ReceivedMqttPacket receivedMqttPacket); | |||||
} | } | ||||
} | } |
@@ -1,6 +1,7 @@ | |||||
using System; | using System; | ||||
using System.IO; | using System.IO; | ||||
using System.Text; | using System.Text; | ||||
using MQTTnet.Core.Adapter; | |||||
using MQTTnet.Core.Exceptions; | using MQTTnet.Core.Exceptions; | ||||
using MQTTnet.Core.Protocol; | using MQTTnet.Core.Protocol; | ||||
using MQTTnet.Core.Packets; | using MQTTnet.Core.Packets; | ||||
@@ -9,15 +10,29 @@ namespace MQTTnet.Core.Serializer | |||||
{ | { | ||||
public sealed class MqttPacketReader : BinaryReader | public sealed class MqttPacketReader : BinaryReader | ||||
{ | { | ||||
private readonly MqttPacketHeader _header; | |||||
public MqttPacketReader(Stream stream, MqttPacketHeader header) | |||||
: base(stream, Encoding.UTF8, true) | |||||
private readonly ReceivedMqttPacket _receivedMqttPacket; | |||||
public MqttPacketReader(ReceivedMqttPacket receivedMqttPacket) | |||||
: base(receivedMqttPacket.Body, Encoding.UTF8, true) | |||||
{ | { | ||||
_header = header; | |||||
_receivedMqttPacket = receivedMqttPacket; | |||||
} | } | ||||
public bool EndOfRemainingData => BaseStream.Position == _header.BodyLength; | |||||
public bool EndOfRemainingData => BaseStream.Position == _receivedMqttPacket.Header.BodyLength; | |||||
public static MqttPacketHeader ReadHeaderFromSource(Stream stream) | |||||
{ | |||||
var fixedHeader = (byte)stream.ReadByte(); | |||||
var controlPacketType = (MqttControlPacketType)(fixedHeader >> 4); | |||||
var bodyLength = ReadBodyLengthFromSource(stream); | |||||
return new MqttPacketHeader | |||||
{ | |||||
FixedHeader = fixedHeader, | |||||
ControlPacketType = controlPacketType, | |||||
BodyLength = bodyLength | |||||
}; | |||||
} | |||||
public override ushort ReadUInt16() | public override ushort ReadUInt16() | ||||
{ | { | ||||
@@ -44,21 +59,7 @@ namespace MQTTnet.Core.Serializer | |||||
public byte[] ReadRemainingData() | public byte[] ReadRemainingData() | ||||
{ | { | ||||
return ReadBytes(_header.BodyLength - (int)BaseStream.Position); | |||||
} | |||||
public static MqttPacketHeader ReadHeaderFromSource(Stream stream) | |||||
{ | |||||
var fixedHeader = (byte)stream.ReadByte(); | |||||
var controlPacketType = (MqttControlPacketType)(fixedHeader >> 4); | |||||
var bodyLength = ReadBodyLengthFromSource(stream); | |||||
return new MqttPacketHeader | |||||
{ | |||||
FixedHeader = fixedHeader, | |||||
ControlPacketType = controlPacketType, | |||||
BodyLength = bodyLength | |||||
}; | |||||
return ReadBytes(_receivedMqttPacket.Header.BodyLength - (int)BaseStream.Position); | |||||
} | } | ||||
private static int ReadBodyLengthFromSource(Stream stream) | private static int ReadBodyLengthFromSource(Stream stream) | ||||
@@ -74,7 +75,7 @@ namespace MQTTnet.Core.Serializer | |||||
multiplier *= 128; | multiplier *= 128; | ||||
if (multiplier > 128 * 128 * 128) | if (multiplier > 128 * 128 * 128) | ||||
{ | { | ||||
throw new MqttProtocolViolationException("Remaining length is ivalid."); | |||||
throw new MqttProtocolViolationException("Remaining length is invalid."); | |||||
} | } | ||||
} while ((encodedByte & 128) != 0); | } while ((encodedByte & 128) != 0); | ||||
return value; | return value; | ||||
@@ -3,6 +3,7 @@ using System.Collections.Generic; | |||||
using System.IO; | using System.IO; | ||||
using System.Linq; | using System.Linq; | ||||
using System.Text; | using System.Text; | ||||
using MQTTnet.Core.Adapter; | |||||
using MQTTnet.Core.Exceptions; | using MQTTnet.Core.Exceptions; | ||||
using MQTTnet.Core.Packets; | using MQTTnet.Core.Packets; | ||||
using MQTTnet.Core.Protocol; | using MQTTnet.Core.Protocol; | ||||
@@ -110,14 +111,13 @@ namespace MQTTnet.Core.Serializer | |||||
throw new MqttProtocolViolationException("Packet type invalid."); | throw new MqttProtocolViolationException("Packet type invalid."); | ||||
} | } | ||||
public MqttBasePacket Deserialize(MqttPacketHeader header, MemoryStream body) | |||||
public MqttBasePacket Deserialize(ReceivedMqttPacket receivedMqttPacket) | |||||
{ | { | ||||
if (header == null) throw new ArgumentNullException(nameof(header)); | |||||
if (body == null) throw new ArgumentNullException(nameof(body)); | |||||
if (receivedMqttPacket == null) throw new ArgumentNullException(nameof(receivedMqttPacket)); | |||||
using (var reader = new MqttPacketReader(body, header)) | |||||
using (var reader = new MqttPacketReader(receivedMqttPacket)) | |||||
{ | { | ||||
return Deserialize(header, reader); | |||||
return Deserialize(receivedMqttPacket.Header, reader); | |||||
} | } | ||||
} | } | ||||
@@ -3,6 +3,7 @@ using System.IO; | |||||
using System.Text; | using System.Text; | ||||
using System.Threading.Tasks; | using System.Threading.Tasks; | ||||
using Microsoft.VisualStudio.TestTools.UnitTesting; | using Microsoft.VisualStudio.TestTools.UnitTesting; | ||||
using MQTTnet.Core.Adapter; | |||||
using MQTTnet.Core.Channel; | using MQTTnet.Core.Channel; | ||||
using MQTTnet.Core.Client; | using MQTTnet.Core.Client; | ||||
using MQTTnet.Core.Packets; | using MQTTnet.Core.Packets; | ||||
@@ -436,20 +437,20 @@ namespace MQTTnet.Core.Tests | |||||
private static void DeserializeAndCompare(MqttBasePacket packet, string expectedBase64Value) | private static void DeserializeAndCompare(MqttBasePacket packet, string expectedBase64Value) | ||||
{ | { | ||||
var serializer = new MqttPacketSerializer(); | var serializer = new MqttPacketSerializer(); | ||||
var buffer1 = serializer.Serialize(packet); | var buffer1 = serializer.Serialize(packet); | ||||
using (var headerStream = new MemoryStream( buffer1 )) | |||||
using (var headerStream = new MemoryStream(buffer1)) | |||||
{ | { | ||||
var header = MqttPacketReader.ReadHeaderFromSource( headerStream ); | |||||
var header = MqttPacketReader.ReadHeaderFromSource(headerStream); | |||||
using (var bodyStream = new MemoryStream( buffer1, (int)headerStream.Position, header.BodyLength )) | |||||
using (var bodyStream = new MemoryStream(buffer1, (int)headerStream.Position, header.BodyLength)) | |||||
{ | { | ||||
var deserializedPacket = serializer.Deserialize(header, bodyStream); | |||||
var buffer2 = serializer.Serialize( deserializedPacket ); | |||||
var deserializedPacket = serializer.Deserialize(new ReceivedMqttPacket(header, bodyStream)); | |||||
var buffer2 = serializer.Serialize(deserializedPacket); | |||||
Assert.AreEqual( expectedBase64Value, Convert.ToBase64String( buffer2 ) ); | |||||
} | |||||
Assert.AreEqual(expectedBase64Value, Convert.ToBase64String(buffer2)); | |||||
} | |||||
} | } | ||||
} | } | ||||
} | } | ||||
@@ -17,18 +17,18 @@ namespace MQTTnet.TestApp.NetFramework | |||||
{ | { | ||||
public static async Task RunAsync() | public static async Task RunAsync() | ||||
{ | { | ||||
var server = Task.Run(() => RunServerAsync()); | |||||
var client = Task.Run(() => RunClientAsync(2000, TimeSpan.FromMilliseconds(10))); | |||||
var server = Task.Factory.StartNew(RunServerAsync, TaskCreationOptions.LongRunning); | |||||
var client = Task.Factory.StartNew(() => RunClientAsync(2000, TimeSpan.FromMilliseconds(10)), TaskCreationOptions.LongRunning); | |||||
await Task.WhenAll(server, client).ConfigureAwait(false); | await Task.WhenAll(server, client).ConfigureAwait(false); | ||||
} | } | ||||
private static Task RunClientsAsync(int msgChunkSize, TimeSpan interval) | private static Task RunClientsAsync(int msgChunkSize, TimeSpan interval) | ||||
{ | { | ||||
return Task.WhenAll(Enumerable.Range(0, 3).Select((i) => Task.Run(() => RunClientAsync(msgChunkSize, interval)))); | |||||
return Task.WhenAll(Enumerable.Range(0, 3).Select(i => Task.Run(() => RunClientAsync(msgChunkSize, interval)))); | |||||
} | } | ||||
private static async Task RunClientAsync( int msgChunkSize, TimeSpan interval ) | |||||
private static async Task RunClientAsync(int msgChunkSize, TimeSpan interval) | |||||
{ | { | ||||
try | try | ||||
{ | { | ||||
@@ -83,7 +83,7 @@ namespace MQTTnet.TestApp.NetFramework | |||||
Console.WriteLine("### WAITING FOR APPLICATION MESSAGES ###"); | Console.WriteLine("### WAITING FOR APPLICATION MESSAGES ###"); | ||||
var testMessageCount = 1000; | |||||
var testMessageCount = 10000; | |||||
var message = CreateMessage(); | var message = CreateMessage(); | ||||
var stopwatch = Stopwatch.StartNew(); | var stopwatch = Stopwatch.StartNew(); | ||||
for (var i = 0; i < testMessageCount; i++) | for (var i = 0; i < testMessageCount; i++) | ||||
@@ -92,8 +92,8 @@ namespace MQTTnet.TestApp.NetFramework | |||||
} | } | ||||
stopwatch.Stop(); | stopwatch.Stop(); | ||||
Console.WriteLine($"Sent 1000 messages within {stopwatch.ElapsedMilliseconds} ms ({stopwatch.ElapsedMilliseconds / (float)testMessageCount} ms / message)."); | |||||
Console.WriteLine($"Sent 10.000 messages within {stopwatch.ElapsedMilliseconds} ms ({stopwatch.ElapsedMilliseconds / (float)testMessageCount} ms / message)."); | |||||
stopwatch.Restart(); | stopwatch.Restart(); | ||||
var sentMessagesCount = 0; | var sentMessagesCount = 0; | ||||
while (stopwatch.ElapsedMilliseconds < 1000) | while (stopwatch.ElapsedMilliseconds < 1000) | ||||
@@ -109,32 +109,32 @@ namespace MQTTnet.TestApp.NetFramework | |||||
while (true) | while (true) | ||||
{ | { | ||||
var msgs = Enumerable.Range( 0, msgChunkSize ) | |||||
.Select( i => CreateMessage() ) | |||||
var msgs = Enumerable.Range(0, msgChunkSize) | |||||
.Select(i => CreateMessage()) | |||||
.ToList(); | .ToList(); | ||||
if (false) | if (false) | ||||
{ | { | ||||
//send concurrent (test for raceconditions) | //send concurrent (test for raceconditions) | ||||
var sendTasks = msgs | var sendTasks = msgs | ||||
.Select( msg => PublishSingleMessage( client, msg, ref msgCount ) ) | |||||
.Select(msg => PublishSingleMessage(client, msg, ref msgCount)) | |||||
.ToList(); | .ToList(); | ||||
await Task.WhenAll( sendTasks ); | |||||
await Task.WhenAll(sendTasks); | |||||
} | } | ||||
else | else | ||||
{ | { | ||||
await client.PublishAsync( msgs ); | |||||
await client.PublishAsync(msgs); | |||||
msgCount += msgs.Count; | msgCount += msgs.Count; | ||||
//send multiple | //send multiple | ||||
} | } | ||||
var now = DateTime.Now; | var now = DateTime.Now; | ||||
if (last < now - TimeSpan.FromSeconds(1)) | if (last < now - TimeSpan.FromSeconds(1)) | ||||
{ | { | ||||
Console.WriteLine( $"sending {msgCount} inteded {msgChunkSize / interval.TotalSeconds}" ); | |||||
Console.WriteLine($"sending {msgCount} inteded {msgChunkSize / interval.TotalSeconds}"); | |||||
msgCount = 0; | msgCount = 0; | ||||
last = now; | last = now; | ||||
} | } | ||||
@@ -152,19 +152,19 @@ namespace MQTTnet.TestApp.NetFramework | |||||
{ | { | ||||
return new MqttApplicationMessage( | return new MqttApplicationMessage( | ||||
"A/B/C", | "A/B/C", | ||||
Encoding.UTF8.GetBytes( "Hello World" ), | |||||
Encoding.UTF8.GetBytes("Hello World"), | |||||
MqttQualityOfServiceLevel.AtMostOnce, | MqttQualityOfServiceLevel.AtMostOnce, | ||||
false | false | ||||
); | ); | ||||
} | } | ||||
private static Task PublishSingleMessage( IMqttClient client, MqttApplicationMessage applicationMessage, ref int count ) | |||||
private static Task PublishSingleMessage(IMqttClient client, MqttApplicationMessage applicationMessage, ref int count) | |||||
{ | { | ||||
Interlocked.Increment( ref count ); | |||||
return Task.Run( () => | |||||
{ | |||||
return client.PublishAsync( applicationMessage ); | |||||
} ); | |||||
Interlocked.Increment(ref count); | |||||
return Task.Run(() => | |||||
{ | |||||
return client.PublishAsync(applicationMessage); | |||||
}); | |||||
} | } | ||||
private static void RunServerAsync() | private static void RunServerAsync() | ||||
@@ -187,19 +187,18 @@ namespace MQTTnet.TestApp.NetFramework | |||||
}, | }, | ||||
DefaultCommunicationTimeout = TimeSpan.FromMinutes(10) | DefaultCommunicationTimeout = TimeSpan.FromMinutes(10) | ||||
}; | }; | ||||
var mqttServer = new MqttServerFactory().CreateMqttServer(options); | var mqttServer = new MqttServerFactory().CreateMqttServer(options); | ||||
var last = DateTime.Now; | |||||
var msgs = 0; | var msgs = 0; | ||||
mqttServer.ApplicationMessageReceived += (sender, args) => | |||||
var stopwatch = Stopwatch.StartNew(); | |||||
mqttServer.ApplicationMessageReceived += (sender, args) => | |||||
{ | { | ||||
msgs++; | msgs++; | ||||
var now = DateTime.Now; | |||||
if (last < now - TimeSpan.FromSeconds(1)) | |||||
if (stopwatch.ElapsedMilliseconds > 1000) | |||||
{ | { | ||||
Console.WriteLine($"received {msgs}"); | Console.WriteLine($"received {msgs}"); | ||||
msgs = 0; | msgs = 0; | ||||
last = now; | |||||
stopwatch.Restart(); | |||||
} | } | ||||
}; | }; | ||||
mqttServer.Start(); | mqttServer.Start(); | ||||
@@ -42,7 +42,7 @@ | |||||
<UseVSHostingProcess>false</UseVSHostingProcess> | <UseVSHostingProcess>false</UseVSHostingProcess> | ||||
<ErrorReport>prompt</ErrorReport> | <ErrorReport>prompt</ErrorReport> | ||||
<Prefer32Bit>true</Prefer32Bit> | <Prefer32Bit>true</Prefer32Bit> | ||||
<UseDotNetNativeToolchain>true</UseDotNetNativeToolchain> | |||||
<UseDotNetNativeToolchain>false</UseDotNetNativeToolchain> | |||||
</PropertyGroup> | </PropertyGroup> | ||||
<PropertyGroup Condition="'$(Configuration)|$(Platform)' == 'Debug|ARM'"> | <PropertyGroup Condition="'$(Configuration)|$(Platform)' == 'Debug|ARM'"> | ||||
<DebugSymbols>true</DebugSymbols> | <DebugSymbols>true</DebugSymbols> | ||||