fixed bug in read async and unfied stream handlingrelease/3.x.x
@@ -7,22 +7,33 @@ using System.Threading.Tasks; | |||
using MQTTnet.Core.Channel; | |||
using MQTTnet.Core.Client; | |||
using MQTTnet.Core.Exceptions; | |||
using System.IO; | |||
namespace MQTTnet.Implementations | |||
{ | |||
public sealed class MqttTcpChannel : IMqttCommunicationChannel, IDisposable | |||
{ | |||
private Stream _dataStream; | |||
private Socket _socket; | |||
private SslStream _sslStream; | |||
/// <summary> | |||
/// called on client sockets are created in connect | |||
/// </summary> | |||
public MqttTcpChannel() | |||
{ | |||
} | |||
/// <summary> | |||
/// called on server, sockets are passed in | |||
/// connect will not be called | |||
/// </summary> | |||
public MqttTcpChannel(Socket socket, SslStream sslStream) | |||
{ | |||
_socket = socket ?? throw new ArgumentNullException(nameof(socket)); | |||
_sslStream = sslStream; | |||
_dataStream = (Stream)sslStream ?? new NetworkStream(socket); | |||
} | |||
public async Task ConnectAsync(MqttClientOptions options) | |||
@@ -40,7 +51,13 @@ namespace MQTTnet.Implementations | |||
if (options.TlsOptions.UseTls) | |||
{ | |||
_sslStream = new SslStream(new NetworkStream(_socket, true)); | |||
await _sslStream.AuthenticateAsClientAsync(options.Server, LoadCertificates(options), SslProtocols.Tls12, options.TlsOptions.CheckCertificateRevocation); | |||
_dataStream = _sslStream; | |||
await _sslStream.AuthenticateAsClientAsync(options.Server, LoadCertificates(options), SslProtocols.Tls12, options.TlsOptions.CheckCertificateRevocation).ConfigureAwait(false); | |||
} | |||
else | |||
{ | |||
_dataStream = new NetworkStream(_socket); | |||
} | |||
} | |||
catch (SocketException exception) | |||
@@ -68,16 +85,7 @@ namespace MQTTnet.Implementations | |||
try | |||
{ | |||
if (_sslStream != null) | |||
{ | |||
return _sslStream.WriteAsync(buffer, 0, buffer.Length); | |||
} | |||
return Task.Factory.FromAsync( | |||
// ReSharper disable once AssignNullToNotNullAttribute | |||
_socket.BeginSend(buffer, 0, buffer.Length, SocketFlags.None, null, null), | |||
_socket.EndSend); | |||
return _dataStream.WriteAsync(buffer, 0, buffer.Length); | |||
} | |||
catch (SocketException exception) | |||
{ | |||
@@ -85,21 +93,26 @@ namespace MQTTnet.Implementations | |||
} | |||
} | |||
public Task ReadAsync(byte[] buffer) | |||
public async Task ReadAsync(byte[] buffer) | |||
{ | |||
if (buffer == null) throw new ArgumentNullException(nameof(buffer)); | |||
try | |||
{ | |||
if (_sslStream != null) | |||
int totalBytes = 0; | |||
do | |||
{ | |||
return _sslStream.ReadAsync(buffer, 0, buffer.Length); | |||
} | |||
var read = await _dataStream.ReadAsync(buffer, totalBytes, buffer.Length - totalBytes).ConfigureAwait(false); | |||
return Task.Factory.FromAsync( | |||
// ReSharper disable once AssignNullToNotNullAttribute | |||
_socket.BeginReceive(buffer, 0, buffer.Length, SocketFlags.None, null, null), | |||
_socket.EndReceive); | |||
if (read == 0) | |||
{ | |||
throw new MqttCommunicationException(new SocketException((int)SocketError.Disconnecting)); | |||
} | |||
totalBytes += read; | |||
} | |||
while (totalBytes < buffer.Length); | |||
} | |||
catch (SocketException exception) | |||
{ | |||
@@ -7,22 +7,32 @@ using System.Threading.Tasks; | |||
using MQTTnet.Core.Channel; | |||
using MQTTnet.Core.Client; | |||
using MQTTnet.Core.Exceptions; | |||
using System.IO; | |||
namespace MQTTnet.Implementations | |||
{ | |||
public sealed class MqttTcpChannel : IMqttCommunicationChannel, IDisposable | |||
{ | |||
private Stream _dataStream; | |||
private Socket _socket; | |||
private SslStream _sslStream; | |||
/// <summary> | |||
/// called on client sockets are created in connect | |||
/// </summary> | |||
public MqttTcpChannel() | |||
{ | |||
} | |||
/// <summary> | |||
/// called on server, sockets are passed in | |||
/// connect will not be called | |||
/// </summary> | |||
public MqttTcpChannel(Socket socket, SslStream sslStream) | |||
{ | |||
_socket = socket ?? throw new ArgumentNullException(nameof(socket)); | |||
_sslStream = sslStream; | |||
_dataStream = (Stream)sslStream ?? new NetworkStream(socket); | |||
} | |||
public async Task ConnectAsync(MqttClientOptions options) | |||
@@ -35,12 +45,17 @@ namespace MQTTnet.Implementations | |||
_socket = new Socket(SocketType.Stream, ProtocolType.Tcp); | |||
} | |||
await _socket.ConnectAsync(options.Server, options.GetPort()); | |||
await _socket.ConnectAsync(options.Server, options.GetPort()).ConfigureAwait(false); | |||
if (options.TlsOptions.UseTls) | |||
{ | |||
_sslStream = new SslStream(new NetworkStream(_socket, true)); | |||
await _sslStream.AuthenticateAsClientAsync(options.Server, LoadCertificates(options), SslProtocols.Tls12, options.TlsOptions.CheckCertificateRevocation); | |||
_dataStream = _sslStream; | |||
await _sslStream.AuthenticateAsClientAsync(options.Server, LoadCertificates(options), SslProtocols.Tls12, options.TlsOptions.CheckCertificateRevocation).ConfigureAwait(false); | |||
} | |||
else | |||
{ | |||
_dataStream = new NetworkStream(_socket); | |||
} | |||
} | |||
catch (SocketException exception) | |||
@@ -68,12 +83,7 @@ namespace MQTTnet.Implementations | |||
try | |||
{ | |||
if (_sslStream != null) | |||
{ | |||
return _sslStream.WriteAsync(buffer, 0, buffer.Length); | |||
} | |||
return _socket.SendAsync(new ArraySegment<byte>(buffer), SocketFlags.None); | |||
return _dataStream.WriteAsync(buffer, 0, buffer.Length); | |||
} | |||
catch (SocketException exception) | |||
{ | |||
@@ -81,18 +91,26 @@ namespace MQTTnet.Implementations | |||
} | |||
} | |||
public Task ReadAsync(byte[] buffer) | |||
public async Task ReadAsync(byte[] buffer) | |||
{ | |||
if (buffer == null) throw new ArgumentNullException(nameof(buffer)); | |||
try | |||
{ | |||
if (_sslStream != null) | |||
int totalBytes = 0; | |||
do | |||
{ | |||
return _sslStream.ReadAsync(buffer, 0, buffer.Length); | |||
} | |||
var read = await _dataStream.ReadAsync(buffer, totalBytes, buffer.Length - totalBytes).ConfigureAwait(false); | |||
return _socket.ReceiveAsync(new ArraySegment<byte>(buffer), SocketFlags.None); | |||
if (read == 0) | |||
{ | |||
throw new MqttCommunicationException(new SocketException((int)SocketError.Disconnecting)); | |||
} | |||
totalBytes += read; | |||
} | |||
while (totalBytes < buffer.Length); | |||
} | |||
catch (SocketException exception) | |||
{ | |||