diff --git a/Frameworks/MQTTnet.NetFramework/Implementations/MqttTcpChannel.cs b/Frameworks/MQTTnet.NetFramework/Implementations/MqttTcpChannel.cs index c9c080d..037538c 100644 --- a/Frameworks/MQTTnet.NetFramework/Implementations/MqttTcpChannel.cs +++ b/Frameworks/MQTTnet.NetFramework/Implementations/MqttTcpChannel.cs @@ -7,11 +7,13 @@ using System.Threading.Tasks; using MQTTnet.Core.Channel; using MQTTnet.Core.Client; using MQTTnet.Core.Exceptions; +using System.IO; namespace MQTTnet.Implementations { public sealed class MqttTcpChannel : IMqttCommunicationChannel, IDisposable { + private Stream _dataStream; private Socket _socket; private SslStream _sslStream; @@ -40,8 +42,13 @@ namespace MQTTnet.Implementations if (options.TlsOptions.UseTls) { _sslStream = new SslStream(new NetworkStream(_socket, true)); + _dataStream = _sslStream; await _sslStream.AuthenticateAsClientAsync(options.Server, LoadCertificates(options), SslProtocols.Tls12, options.TlsOptions.CheckCertificateRevocation); } + else + { + _dataStream = new NetworkStream(_socket); + } } catch (SocketException exception) { @@ -68,16 +75,7 @@ namespace MQTTnet.Implementations try { - if (_sslStream != null) - { - return _sslStream.WriteAsync(buffer, 0, buffer.Length); - } - - return Task.Factory.FromAsync( - // ReSharper disable once AssignNullToNotNullAttribute - _socket.BeginSend(buffer, 0, buffer.Length, SocketFlags.None, null, null), - _socket.EndSend); - + return _dataStream.WriteAsync(buffer, 0, buffer.Length); } catch (SocketException exception) { @@ -85,21 +83,26 @@ namespace MQTTnet.Implementations } } - public Task ReadAsync(byte[] buffer) + public async Task ReadAsync(byte[] buffer) { if (buffer == null) throw new ArgumentNullException(nameof(buffer)); try { - if (_sslStream != null) + int totalBytes = 0; + + do { - return _sslStream.ReadAsync(buffer, 0, buffer.Length); - } + var read = await _dataStream.ReadAsync(buffer, totalBytes, buffer.Length - totalBytes); - return Task.Factory.FromAsync( - // ReSharper disable once AssignNullToNotNullAttribute - _socket.BeginReceive(buffer, 0, buffer.Length, SocketFlags.None, null, null), - _socket.EndReceive); + if (read == 0) + { + throw new MqttCommunicationException(new SocketException((int)SocketError.Disconnecting)); + } + + totalBytes += read; + } + while (totalBytes < buffer.Length); } catch (SocketException exception) { diff --git a/Frameworks/MQTTnet.NetStandard/Implementations/MqttTcpChannel.cs b/Frameworks/MQTTnet.NetStandard/Implementations/MqttTcpChannel.cs index a4247b0..7f59a27 100644 --- a/Frameworks/MQTTnet.NetStandard/Implementations/MqttTcpChannel.cs +++ b/Frameworks/MQTTnet.NetStandard/Implementations/MqttTcpChannel.cs @@ -7,11 +7,13 @@ using System.Threading.Tasks; using MQTTnet.Core.Channel; using MQTTnet.Core.Client; using MQTTnet.Core.Exceptions; +using System.IO; namespace MQTTnet.Implementations { public sealed class MqttTcpChannel : IMqttCommunicationChannel, IDisposable { + private Stream _dataStream; private Socket _socket; private SslStream _sslStream; @@ -36,12 +38,17 @@ namespace MQTTnet.Implementations } await _socket.ConnectAsync(options.Server, options.GetPort()); - + if (options.TlsOptions.UseTls) { _sslStream = new SslStream(new NetworkStream(_socket, true)); + _dataStream = _sslStream; await _sslStream.AuthenticateAsClientAsync(options.Server, LoadCertificates(options), SslProtocols.Tls12, options.TlsOptions.CheckCertificateRevocation); } + else + { + _dataStream = new NetworkStream(_socket); + } } catch (SocketException exception) { @@ -68,12 +75,7 @@ namespace MQTTnet.Implementations try { - if (_sslStream != null) - { - return _sslStream.WriteAsync(buffer, 0, buffer.Length); - } - - return _socket.SendAsync(new ArraySegment(buffer), SocketFlags.None); + return _dataStream.WriteAsync(buffer, 0, buffer.Length); } catch (SocketException exception) { @@ -81,18 +83,26 @@ namespace MQTTnet.Implementations } } - public Task ReadAsync(byte[] buffer) + public async Task ReadAsync(byte[] buffer) { if (buffer == null) throw new ArgumentNullException(nameof(buffer)); try { - if (_sslStream != null) + int totalBytes = 0; + + do { - return _sslStream.ReadAsync(buffer, 0, buffer.Length); - } + var read = await _dataStream.ReadAsync(buffer, totalBytes, buffer.Length - totalBytes); - return _socket.ReceiveAsync(new ArraySegment(buffer), SocketFlags.None); + if (read == 0) + { + throw new MqttCommunicationException(new SocketException((int)SocketError.Disconnecting)); + } + + totalBytes += read; + } + while (totalBytes < buffer.Length); } catch (SocketException exception) {