fixed bug in read async and unfied stream handlingrelease/3.x.x
@@ -7,22 +7,33 @@ using System.Threading.Tasks; | |||||
using MQTTnet.Core.Channel; | using MQTTnet.Core.Channel; | ||||
using MQTTnet.Core.Client; | using MQTTnet.Core.Client; | ||||
using MQTTnet.Core.Exceptions; | using MQTTnet.Core.Exceptions; | ||||
using System.IO; | |||||
namespace MQTTnet.Implementations | namespace MQTTnet.Implementations | ||||
{ | { | ||||
public sealed class MqttTcpChannel : IMqttCommunicationChannel, IDisposable | public sealed class MqttTcpChannel : IMqttCommunicationChannel, IDisposable | ||||
{ | { | ||||
private Stream _dataStream; | |||||
private Socket _socket; | private Socket _socket; | ||||
private SslStream _sslStream; | private SslStream _sslStream; | ||||
/// <summary> | |||||
/// called on client sockets are created in connect | |||||
/// </summary> | |||||
public MqttTcpChannel() | public MqttTcpChannel() | ||||
{ | { | ||||
} | } | ||||
/// <summary> | |||||
/// called on server, sockets are passed in | |||||
/// connect will not be called | |||||
/// </summary> | |||||
public MqttTcpChannel(Socket socket, SslStream sslStream) | public MqttTcpChannel(Socket socket, SslStream sslStream) | ||||
{ | { | ||||
_socket = socket ?? throw new ArgumentNullException(nameof(socket)); | _socket = socket ?? throw new ArgumentNullException(nameof(socket)); | ||||
_sslStream = sslStream; | _sslStream = sslStream; | ||||
_dataStream = (Stream)sslStream ?? new NetworkStream(socket); | |||||
} | } | ||||
public async Task ConnectAsync(MqttClientOptions options) | public async Task ConnectAsync(MqttClientOptions options) | ||||
@@ -40,7 +51,13 @@ namespace MQTTnet.Implementations | |||||
if (options.TlsOptions.UseTls) | if (options.TlsOptions.UseTls) | ||||
{ | { | ||||
_sslStream = new SslStream(new NetworkStream(_socket, true)); | _sslStream = new SslStream(new NetworkStream(_socket, true)); | ||||
await _sslStream.AuthenticateAsClientAsync(options.Server, LoadCertificates(options), SslProtocols.Tls12, options.TlsOptions.CheckCertificateRevocation); | |||||
_dataStream = _sslStream; | |||||
await _sslStream.AuthenticateAsClientAsync(options.Server, LoadCertificates(options), SslProtocols.Tls12, options.TlsOptions.CheckCertificateRevocation).ConfigureAwait(false); | |||||
} | |||||
else | |||||
{ | |||||
_dataStream = new NetworkStream(_socket); | |||||
} | } | ||||
} | } | ||||
catch (SocketException exception) | catch (SocketException exception) | ||||
@@ -68,16 +85,7 @@ namespace MQTTnet.Implementations | |||||
try | try | ||||
{ | { | ||||
if (_sslStream != null) | |||||
{ | |||||
return _sslStream.WriteAsync(buffer, 0, buffer.Length); | |||||
} | |||||
return Task.Factory.FromAsync( | |||||
// ReSharper disable once AssignNullToNotNullAttribute | |||||
_socket.BeginSend(buffer, 0, buffer.Length, SocketFlags.None, null, null), | |||||
_socket.EndSend); | |||||
return _dataStream.WriteAsync(buffer, 0, buffer.Length); | |||||
} | } | ||||
catch (SocketException exception) | catch (SocketException exception) | ||||
{ | { | ||||
@@ -85,21 +93,26 @@ namespace MQTTnet.Implementations | |||||
} | } | ||||
} | } | ||||
public Task ReadAsync(byte[] buffer) | |||||
public async Task ReadAsync(byte[] buffer) | |||||
{ | { | ||||
if (buffer == null) throw new ArgumentNullException(nameof(buffer)); | if (buffer == null) throw new ArgumentNullException(nameof(buffer)); | ||||
try | try | ||||
{ | { | ||||
if (_sslStream != null) | |||||
int totalBytes = 0; | |||||
do | |||||
{ | { | ||||
return _sslStream.ReadAsync(buffer, 0, buffer.Length); | |||||
} | |||||
var read = await _dataStream.ReadAsync(buffer, totalBytes, buffer.Length - totalBytes).ConfigureAwait(false); | |||||
return Task.Factory.FromAsync( | |||||
// ReSharper disable once AssignNullToNotNullAttribute | |||||
_socket.BeginReceive(buffer, 0, buffer.Length, SocketFlags.None, null, null), | |||||
_socket.EndReceive); | |||||
if (read == 0) | |||||
{ | |||||
throw new MqttCommunicationException(new SocketException((int)SocketError.Disconnecting)); | |||||
} | |||||
totalBytes += read; | |||||
} | |||||
while (totalBytes < buffer.Length); | |||||
} | } | ||||
catch (SocketException exception) | catch (SocketException exception) | ||||
{ | { | ||||
@@ -7,22 +7,32 @@ using System.Threading.Tasks; | |||||
using MQTTnet.Core.Channel; | using MQTTnet.Core.Channel; | ||||
using MQTTnet.Core.Client; | using MQTTnet.Core.Client; | ||||
using MQTTnet.Core.Exceptions; | using MQTTnet.Core.Exceptions; | ||||
using System.IO; | |||||
namespace MQTTnet.Implementations | namespace MQTTnet.Implementations | ||||
{ | { | ||||
public sealed class MqttTcpChannel : IMqttCommunicationChannel, IDisposable | public sealed class MqttTcpChannel : IMqttCommunicationChannel, IDisposable | ||||
{ | { | ||||
private Stream _dataStream; | |||||
private Socket _socket; | private Socket _socket; | ||||
private SslStream _sslStream; | private SslStream _sslStream; | ||||
/// <summary> | |||||
/// called on client sockets are created in connect | |||||
/// </summary> | |||||
public MqttTcpChannel() | public MqttTcpChannel() | ||||
{ | { | ||||
} | } | ||||
/// <summary> | |||||
/// called on server, sockets are passed in | |||||
/// connect will not be called | |||||
/// </summary> | |||||
public MqttTcpChannel(Socket socket, SslStream sslStream) | public MqttTcpChannel(Socket socket, SslStream sslStream) | ||||
{ | { | ||||
_socket = socket ?? throw new ArgumentNullException(nameof(socket)); | _socket = socket ?? throw new ArgumentNullException(nameof(socket)); | ||||
_sslStream = sslStream; | _sslStream = sslStream; | ||||
_dataStream = (Stream)sslStream ?? new NetworkStream(socket); | |||||
} | } | ||||
public async Task ConnectAsync(MqttClientOptions options) | public async Task ConnectAsync(MqttClientOptions options) | ||||
@@ -35,12 +45,17 @@ namespace MQTTnet.Implementations | |||||
_socket = new Socket(SocketType.Stream, ProtocolType.Tcp); | _socket = new Socket(SocketType.Stream, ProtocolType.Tcp); | ||||
} | } | ||||
await _socket.ConnectAsync(options.Server, options.GetPort()); | |||||
await _socket.ConnectAsync(options.Server, options.GetPort()).ConfigureAwait(false); | |||||
if (options.TlsOptions.UseTls) | if (options.TlsOptions.UseTls) | ||||
{ | { | ||||
_sslStream = new SslStream(new NetworkStream(_socket, true)); | _sslStream = new SslStream(new NetworkStream(_socket, true)); | ||||
await _sslStream.AuthenticateAsClientAsync(options.Server, LoadCertificates(options), SslProtocols.Tls12, options.TlsOptions.CheckCertificateRevocation); | |||||
_dataStream = _sslStream; | |||||
await _sslStream.AuthenticateAsClientAsync(options.Server, LoadCertificates(options), SslProtocols.Tls12, options.TlsOptions.CheckCertificateRevocation).ConfigureAwait(false); | |||||
} | |||||
else | |||||
{ | |||||
_dataStream = new NetworkStream(_socket); | |||||
} | } | ||||
} | } | ||||
catch (SocketException exception) | catch (SocketException exception) | ||||
@@ -68,12 +83,7 @@ namespace MQTTnet.Implementations | |||||
try | try | ||||
{ | { | ||||
if (_sslStream != null) | |||||
{ | |||||
return _sslStream.WriteAsync(buffer, 0, buffer.Length); | |||||
} | |||||
return _socket.SendAsync(new ArraySegment<byte>(buffer), SocketFlags.None); | |||||
return _dataStream.WriteAsync(buffer, 0, buffer.Length); | |||||
} | } | ||||
catch (SocketException exception) | catch (SocketException exception) | ||||
{ | { | ||||
@@ -81,18 +91,26 @@ namespace MQTTnet.Implementations | |||||
} | } | ||||
} | } | ||||
public Task ReadAsync(byte[] buffer) | |||||
public async Task ReadAsync(byte[] buffer) | |||||
{ | { | ||||
if (buffer == null) throw new ArgumentNullException(nameof(buffer)); | if (buffer == null) throw new ArgumentNullException(nameof(buffer)); | ||||
try | try | ||||
{ | { | ||||
if (_sslStream != null) | |||||
int totalBytes = 0; | |||||
do | |||||
{ | { | ||||
return _sslStream.ReadAsync(buffer, 0, buffer.Length); | |||||
} | |||||
var read = await _dataStream.ReadAsync(buffer, totalBytes, buffer.Length - totalBytes).ConfigureAwait(false); | |||||
return _socket.ReceiveAsync(new ArraySegment<byte>(buffer), SocketFlags.None); | |||||
if (read == 0) | |||||
{ | |||||
throw new MqttCommunicationException(new SocketException((int)SocketError.Disconnecting)); | |||||
} | |||||
totalBytes += read; | |||||
} | |||||
while (totalBytes < buffer.Length); | |||||
} | } | ||||
catch (SocketException exception) | catch (SocketException exception) | ||||
{ | { | ||||