Browse Source

fixed bug in read async and unfied stream handling

release/3.x.x
JanEggers 7 years ago
parent
commit
caa857f318
2 changed files with 43 additions and 30 deletions
  1. +21
    -18
      Frameworks/MQTTnet.NetFramework/Implementations/MqttTcpChannel.cs
  2. +22
    -12
      Frameworks/MQTTnet.NetStandard/Implementations/MqttTcpChannel.cs

+ 21
- 18
Frameworks/MQTTnet.NetFramework/Implementations/MqttTcpChannel.cs View File

@@ -7,11 +7,13 @@ using System.Threading.Tasks;
using MQTTnet.Core.Channel;
using MQTTnet.Core.Client;
using MQTTnet.Core.Exceptions;
using System.IO;

namespace MQTTnet.Implementations
{
public sealed class MqttTcpChannel : IMqttCommunicationChannel, IDisposable
{
private Stream _dataStream;
private Socket _socket;
private SslStream _sslStream;

@@ -40,8 +42,13 @@ namespace MQTTnet.Implementations
if (options.TlsOptions.UseTls)
{
_sslStream = new SslStream(new NetworkStream(_socket, true));
_dataStream = _sslStream;
await _sslStream.AuthenticateAsClientAsync(options.Server, LoadCertificates(options), SslProtocols.Tls12, options.TlsOptions.CheckCertificateRevocation);
}
else
{
_dataStream = new NetworkStream(_socket);
}
}
catch (SocketException exception)
{
@@ -68,16 +75,7 @@ namespace MQTTnet.Implementations

try
{
if (_sslStream != null)
{
return _sslStream.WriteAsync(buffer, 0, buffer.Length);
}

return Task.Factory.FromAsync(
// ReSharper disable once AssignNullToNotNullAttribute
_socket.BeginSend(buffer, 0, buffer.Length, SocketFlags.None, null, null),
_socket.EndSend);

return _dataStream.WriteAsync(buffer, 0, buffer.Length);
}
catch (SocketException exception)
{
@@ -85,21 +83,26 @@ namespace MQTTnet.Implementations
}
}

public Task ReadAsync(byte[] buffer)
public async Task ReadAsync(byte[] buffer)
{
if (buffer == null) throw new ArgumentNullException(nameof(buffer));

try
{
if (_sslStream != null)
int totalBytes = 0;

do
{
return _sslStream.ReadAsync(buffer, 0, buffer.Length);
}
var read = await _dataStream.ReadAsync(buffer, totalBytes, buffer.Length - totalBytes);

return Task.Factory.FromAsync(
// ReSharper disable once AssignNullToNotNullAttribute
_socket.BeginReceive(buffer, 0, buffer.Length, SocketFlags.None, null, null),
_socket.EndReceive);
if (read == 0)
{
throw new MqttCommunicationException(new SocketException((int)SocketError.Disconnecting));
}

totalBytes += read;
}
while (totalBytes < buffer.Length);
}
catch (SocketException exception)
{


+ 22
- 12
Frameworks/MQTTnet.NetStandard/Implementations/MqttTcpChannel.cs View File

@@ -7,11 +7,13 @@ using System.Threading.Tasks;
using MQTTnet.Core.Channel;
using MQTTnet.Core.Client;
using MQTTnet.Core.Exceptions;
using System.IO;

namespace MQTTnet.Implementations
{
public sealed class MqttTcpChannel : IMqttCommunicationChannel, IDisposable
{
private Stream _dataStream;
private Socket _socket;
private SslStream _sslStream;

@@ -36,12 +38,17 @@ namespace MQTTnet.Implementations
}

await _socket.ConnectAsync(options.Server, options.GetPort());
if (options.TlsOptions.UseTls)
{
_sslStream = new SslStream(new NetworkStream(_socket, true));
_dataStream = _sslStream;
await _sslStream.AuthenticateAsClientAsync(options.Server, LoadCertificates(options), SslProtocols.Tls12, options.TlsOptions.CheckCertificateRevocation);
}
else
{
_dataStream = new NetworkStream(_socket);
}
}
catch (SocketException exception)
{
@@ -68,12 +75,7 @@ namespace MQTTnet.Implementations

try
{
if (_sslStream != null)
{
return _sslStream.WriteAsync(buffer, 0, buffer.Length);
}

return _socket.SendAsync(new ArraySegment<byte>(buffer), SocketFlags.None);
return _dataStream.WriteAsync(buffer, 0, buffer.Length);
}
catch (SocketException exception)
{
@@ -81,18 +83,26 @@ namespace MQTTnet.Implementations
}
}

public Task ReadAsync(byte[] buffer)
public async Task ReadAsync(byte[] buffer)
{
if (buffer == null) throw new ArgumentNullException(nameof(buffer));

try
{
if (_sslStream != null)
int totalBytes = 0;

do
{
return _sslStream.ReadAsync(buffer, 0, buffer.Length);
}
var read = await _dataStream.ReadAsync(buffer, totalBytes, buffer.Length - totalBytes);

return _socket.ReceiveAsync(new ArraySegment<byte>(buffer), SocketFlags.None);
if (read == 0)
{
throw new MqttCommunicationException(new SocketException((int)SocketError.Disconnecting));
}

totalBytes += read;
}
while (totalBytes < buffer.Length);
}
catch (SocketException exception)
{


Loading…
Cancel
Save