Browse Source

Merge pull request #34 from JanEggers/Fix_Bug33

fixed bug in read async and unfied stream handling
release/3.x.x
Christian 7 years ago
committed by GitHub
parent
commit
4a6790f0aa
2 changed files with 64 additions and 33 deletions
  1. +32
    -19
      Frameworks/MQTTnet.NetFramework/Implementations/MqttTcpChannel.cs
  2. +32
    -14
      Frameworks/MQTTnet.NetStandard/Implementations/MqttTcpChannel.cs

+ 32
- 19
Frameworks/MQTTnet.NetFramework/Implementations/MqttTcpChannel.cs View File

@@ -7,22 +7,33 @@ using System.Threading.Tasks;
using MQTTnet.Core.Channel; using MQTTnet.Core.Channel;
using MQTTnet.Core.Client; using MQTTnet.Core.Client;
using MQTTnet.Core.Exceptions; using MQTTnet.Core.Exceptions;
using System.IO;


namespace MQTTnet.Implementations namespace MQTTnet.Implementations
{ {
public sealed class MqttTcpChannel : IMqttCommunicationChannel, IDisposable public sealed class MqttTcpChannel : IMqttCommunicationChannel, IDisposable
{ {
private Stream _dataStream;
private Socket _socket; private Socket _socket;
private SslStream _sslStream; private SslStream _sslStream;


/// <summary>
/// called on client sockets are created in connect
/// </summary>
public MqttTcpChannel() public MqttTcpChannel()
{ {
} }


/// <summary>
/// called on server, sockets are passed in
/// connect will not be called
/// </summary>
public MqttTcpChannel(Socket socket, SslStream sslStream) public MqttTcpChannel(Socket socket, SslStream sslStream)
{ {
_socket = socket ?? throw new ArgumentNullException(nameof(socket)); _socket = socket ?? throw new ArgumentNullException(nameof(socket));
_sslStream = sslStream; _sslStream = sslStream;
_dataStream = (Stream)sslStream ?? new NetworkStream(socket);
} }


public async Task ConnectAsync(MqttClientOptions options) public async Task ConnectAsync(MqttClientOptions options)
@@ -40,7 +51,13 @@ namespace MQTTnet.Implementations
if (options.TlsOptions.UseTls) if (options.TlsOptions.UseTls)
{ {
_sslStream = new SslStream(new NetworkStream(_socket, true)); _sslStream = new SslStream(new NetworkStream(_socket, true));
await _sslStream.AuthenticateAsClientAsync(options.Server, LoadCertificates(options), SslProtocols.Tls12, options.TlsOptions.CheckCertificateRevocation);
_dataStream = _sslStream;
await _sslStream.AuthenticateAsClientAsync(options.Server, LoadCertificates(options), SslProtocols.Tls12, options.TlsOptions.CheckCertificateRevocation).ConfigureAwait(false);
}
else
{
_dataStream = new NetworkStream(_socket);
} }
} }
catch (SocketException exception) catch (SocketException exception)
@@ -68,16 +85,7 @@ namespace MQTTnet.Implementations


try try
{ {
if (_sslStream != null)
{
return _sslStream.WriteAsync(buffer, 0, buffer.Length);
}

return Task.Factory.FromAsync(
// ReSharper disable once AssignNullToNotNullAttribute
_socket.BeginSend(buffer, 0, buffer.Length, SocketFlags.None, null, null),
_socket.EndSend);

return _dataStream.WriteAsync(buffer, 0, buffer.Length);
} }
catch (SocketException exception) catch (SocketException exception)
{ {
@@ -85,21 +93,26 @@ namespace MQTTnet.Implementations
} }
} }


public Task ReadAsync(byte[] buffer)
public async Task ReadAsync(byte[] buffer)
{ {
if (buffer == null) throw new ArgumentNullException(nameof(buffer)); if (buffer == null) throw new ArgumentNullException(nameof(buffer));


try try
{ {
if (_sslStream != null)
int totalBytes = 0;

do
{ {
return _sslStream.ReadAsync(buffer, 0, buffer.Length);
}
var read = await _dataStream.ReadAsync(buffer, totalBytes, buffer.Length - totalBytes).ConfigureAwait(false);


return Task.Factory.FromAsync(
// ReSharper disable once AssignNullToNotNullAttribute
_socket.BeginReceive(buffer, 0, buffer.Length, SocketFlags.None, null, null),
_socket.EndReceive);
if (read == 0)
{
throw new MqttCommunicationException(new SocketException((int)SocketError.Disconnecting));
}

totalBytes += read;
}
while (totalBytes < buffer.Length);
} }
catch (SocketException exception) catch (SocketException exception)
{ {


+ 32
- 14
Frameworks/MQTTnet.NetStandard/Implementations/MqttTcpChannel.cs View File

@@ -7,22 +7,32 @@ using System.Threading.Tasks;
using MQTTnet.Core.Channel; using MQTTnet.Core.Channel;
using MQTTnet.Core.Client; using MQTTnet.Core.Client;
using MQTTnet.Core.Exceptions; using MQTTnet.Core.Exceptions;
using System.IO;


namespace MQTTnet.Implementations namespace MQTTnet.Implementations
{ {
public sealed class MqttTcpChannel : IMqttCommunicationChannel, IDisposable public sealed class MqttTcpChannel : IMqttCommunicationChannel, IDisposable
{ {
private Stream _dataStream;
private Socket _socket; private Socket _socket;
private SslStream _sslStream; private SslStream _sslStream;


/// <summary>
/// called on client sockets are created in connect
/// </summary>
public MqttTcpChannel() public MqttTcpChannel()
{ {
} }


/// <summary>
/// called on server, sockets are passed in
/// connect will not be called
/// </summary>
public MqttTcpChannel(Socket socket, SslStream sslStream) public MqttTcpChannel(Socket socket, SslStream sslStream)
{ {
_socket = socket ?? throw new ArgumentNullException(nameof(socket)); _socket = socket ?? throw new ArgumentNullException(nameof(socket));
_sslStream = sslStream; _sslStream = sslStream;
_dataStream = (Stream)sslStream ?? new NetworkStream(socket);
} }


public async Task ConnectAsync(MqttClientOptions options) public async Task ConnectAsync(MqttClientOptions options)
@@ -35,12 +45,17 @@ namespace MQTTnet.Implementations
_socket = new Socket(SocketType.Stream, ProtocolType.Tcp); _socket = new Socket(SocketType.Stream, ProtocolType.Tcp);
} }


await _socket.ConnectAsync(options.Server, options.GetPort());
await _socket.ConnectAsync(options.Server, options.GetPort()).ConfigureAwait(false);
if (options.TlsOptions.UseTls) if (options.TlsOptions.UseTls)
{ {
_sslStream = new SslStream(new NetworkStream(_socket, true)); _sslStream = new SslStream(new NetworkStream(_socket, true));
await _sslStream.AuthenticateAsClientAsync(options.Server, LoadCertificates(options), SslProtocols.Tls12, options.TlsOptions.CheckCertificateRevocation);
_dataStream = _sslStream;
await _sslStream.AuthenticateAsClientAsync(options.Server, LoadCertificates(options), SslProtocols.Tls12, options.TlsOptions.CheckCertificateRevocation).ConfigureAwait(false);
}
else
{
_dataStream = new NetworkStream(_socket);
} }
} }
catch (SocketException exception) catch (SocketException exception)
@@ -68,12 +83,7 @@ namespace MQTTnet.Implementations


try try
{ {
if (_sslStream != null)
{
return _sslStream.WriteAsync(buffer, 0, buffer.Length);
}

return _socket.SendAsync(new ArraySegment<byte>(buffer), SocketFlags.None);
return _dataStream.WriteAsync(buffer, 0, buffer.Length);
} }
catch (SocketException exception) catch (SocketException exception)
{ {
@@ -81,18 +91,26 @@ namespace MQTTnet.Implementations
} }
} }


public Task ReadAsync(byte[] buffer)
public async Task ReadAsync(byte[] buffer)
{ {
if (buffer == null) throw new ArgumentNullException(nameof(buffer)); if (buffer == null) throw new ArgumentNullException(nameof(buffer));


try try
{ {
if (_sslStream != null)
int totalBytes = 0;

do
{ {
return _sslStream.ReadAsync(buffer, 0, buffer.Length);
}
var read = await _dataStream.ReadAsync(buffer, totalBytes, buffer.Length - totalBytes).ConfigureAwait(false);


return _socket.ReceiveAsync(new ArraySegment<byte>(buffer), SocketFlags.None);
if (read == 0)
{
throw new MqttCommunicationException(new SocketException((int)SocketError.Disconnecting));
}

totalBytes += read;
}
while (totalBytes < buffer.Length);
} }
catch (SocketException exception) catch (SocketException exception)
{ {


Loading…
Cancel
Save