using System;
using System.Collections.Generic;
using System.IO;
using System.Net.Sockets;
using System.Runtime.InteropServices;
using System.Threading;
using System.Threading.Tasks;
using MQTTnet.Channel;
using MQTTnet.Diagnostics;
using MQTTnet.Exceptions;
using MQTTnet.Internal;
using MQTTnet.Packets;
using MQTTnet.Serializer;

namespace MQTTnet.Adapter
{
    /// <summary>
    /// Adapter which sends and receives MQTT packets over an <see cref="IMqttChannel"/>
    /// using an <see cref="IMqttPacketSerializer"/> for the wire format.
    /// Most low level exceptions are translated by <c>ExecuteAndWrapExceptionAsync</c>.
    /// </summary>
    public class MqttChannelAdapter : IMqttChannelAdapter
    {
        // HRESULT compared against COMException.HResult; a match is reported as a cancellation.
        // 0x8007xxxx is the Win32 facility and 0x03E3 (995) is ERROR_OPERATION_ABORTED.
        private const uint ErrorOperationAborted = 0x800703E3;

        // Upper bound for the chunk size used when reading a packet body.
        private const int ReadBufferSize = 4096; // TODO: Move buffer size to config

        private readonly IMqttNetChildLogger _logger;
        private readonly IMqttChannel _channel;

        private bool _isDisposed;

        /// <summary>
        /// Creates a new adapter on top of the given channel.
        /// </summary>
        /// <param name="channel">The channel used for all I/O. It is disposed together with this adapter.</param>
        /// <param name="serializer">The serializer used to convert packets to and from their binary form.</param>
        /// <param name="logger">The parent logger; a child logger named after this class is created from it.</param>
        /// <exception cref="ArgumentNullException">Any argument is <c>null</c>.</exception>
        public MqttChannelAdapter(IMqttChannel channel, IMqttPacketSerializer serializer, IMqttNetChildLogger logger)
        {
            if (logger == null) throw new ArgumentNullException(nameof(logger));

            _channel = channel ?? throw new ArgumentNullException(nameof(channel));
            PacketSerializer = serializer ?? throw new ArgumentNullException(nameof(serializer));

            _logger = logger.CreateChildLogger(nameof(MqttChannelAdapter));
        }

        /// <summary>Gets the endpoint reported by the underlying channel.</summary>
        public string Endpoint => _channel.Endpoint;

        /// <summary>Gets the serializer used for outgoing and incoming packets.</summary>
        public IMqttPacketSerializer PacketSerializer { get; }

        /// <summary>Raised after the fixed header of an incoming packet has been read.</summary>
        public event EventHandler ReadingPacketStarted;

        /// <summary>
        /// Raised when reading the remainder of a packet has finished (successfully or not).
        /// Only raised when <see cref="ReadingPacketStarted"/> was raised before.
        /// </summary>
        public event EventHandler ReadingPacketCompleted;

        /// <summary>
        /// Connects the underlying channel.
        /// </summary>
        /// <param name="timeout">The timeout handed to <c>TaskExtensions.TimeoutAfter</c>.</param>
        /// <param name="cancellationToken">The token handed to <c>TaskExtensions.TimeoutAfter</c>.</param>
        /// <exception cref="ObjectDisposedException">The adapter is already disposed.</exception>
        public Task ConnectAsync(TimeSpan timeout, CancellationToken cancellationToken)
        {
            ThrowIfDisposed();

            _logger.Verbose("Connecting [Timeout={0}]", timeout);

            return ExecuteAndWrapExceptionAsync(() =>
                Internal.TaskExtensions.TimeoutAfter(ct => _channel.ConnectAsync(ct), timeout, cancellationToken));
        }

        /// <summary>
        /// Disconnects the underlying channel.
        /// </summary>
        /// <param name="timeout">The timeout handed to <c>TaskExtensions.TimeoutAfter</c>.</param>
        /// <param name="cancellationToken">The token handed to <c>TaskExtensions.TimeoutAfter</c>.</param>
        /// <exception cref="ObjectDisposedException">The adapter is already disposed.</exception>
        public Task DisconnectAsync(TimeSpan timeout, CancellationToken cancellationToken)
        {
            ThrowIfDisposed();

            _logger.Verbose("Disconnecting [Timeout={0}]", timeout);

            // The token provided by TimeoutAfter is not forwarded because the channel's
            // DisconnectAsync is called without arguments here.
            return ExecuteAndWrapExceptionAsync(() =>
                Internal.TaskExtensions.TimeoutAfter(ct => _channel.DisconnectAsync(), timeout, cancellationToken));
        }

        /// <summary>
        /// Serializes and sends the given packets one after another. <c>null</c> entries are skipped.
        /// </summary>
        /// <param name="timeout">The timeout applied to each individual packet.</param>
        /// <param name="packets">The packets to send.</param>
        /// <param name="cancellationToken">The token used to cancel the operation.</param>
        /// <exception cref="ObjectDisposedException">The adapter is already disposed.</exception>
        /// <exception cref="ArgumentNullException"><paramref name="packets"/> is <c>null</c>.</exception>
        public async Task SendPacketsAsync(TimeSpan timeout, IEnumerable<MqttBasePacket> packets, CancellationToken cancellationToken)
        {
            ThrowIfDisposed();

            if (packets == null) throw new ArgumentNullException(nameof(packets));

            foreach (var packet in packets)
            {
                if (packet == null)
                {
                    continue;
                }

                await SendPacketAsync(timeout, cancellationToken, packet).ConfigureAwait(false);
            }
        }

        // Serializes a single packet and writes it to the channel.
        private Task SendPacketAsync(TimeSpan timeout, CancellationToken cancellationToken, MqttBasePacket packet)
        {
            return ExecuteAndWrapExceptionAsync(() =>
            {
                _logger.Verbose("TX >>> {0} [Timeout={1}]", packet, timeout);

                var packetData = PacketSerializer.Serialize(packet);

                return Internal.TaskExtensions.TimeoutAfter(ct => _channel.WriteAsync(
                    packetData.Array,
                    packetData.Offset,
                    packetData.Count,
                    ct), timeout, cancellationToken);
            });
        }

        /// <summary>
        /// Receives and deserializes the next packet from the channel.
        /// </summary>
        /// <param name="timeout">
        /// The timeout for receiving the packet. A value which is not greater than
        /// <see cref="TimeSpan.Zero"/> means that no timeout wrapper is used.
        /// </param>
        /// <param name="cancellationToken">The token used to cancel the operation.</param>
        /// <returns>The deserialized packet.</returns>
        /// <exception cref="ObjectDisposedException">The adapter is already disposed.</exception>
        /// <exception cref="TaskCanceledException">Nothing was received or cancellation was requested.</exception>
        /// <exception cref="MqttCommunicationException">
        /// Reading or deserializing failed. NOTE(review): the <c>MqttProtocolViolationException</c> thrown for a
        /// packet which deserializes to <c>null</c> is wrapped as well unless it derives from one of the
        /// exception types which are rethrown unchanged - confirm against its declaration.
        /// </exception>
        public async Task<MqttBasePacket> ReceivePacketAsync(TimeSpan timeout, CancellationToken cancellationToken)
        {
            ThrowIfDisposed();

            MqttBasePacket packet = null;

            await ExecuteAndWrapExceptionAsync(async () =>
            {
                ReceivedMqttPacket receivedMqttPacket = null;
                try
                {
                    if (timeout > TimeSpan.Zero)
                    {
                        receivedMqttPacket = await Internal.TaskExtensions.TimeoutAfter(ct => ReceiveAsync(_channel, ct), timeout, cancellationToken).ConfigureAwait(false);
                    }
                    else
                    {
                        receivedMqttPacket = await ReceiveAsync(_channel, cancellationToken).ConfigureAwait(false);
                    }

                    if (receivedMqttPacket == null || cancellationToken.IsCancellationRequested)
                    {
                        throw new TaskCanceledException();
                    }

                    packet = PacketSerializer.Deserialize(receivedMqttPacket);
                    if (packet == null)
                    {
                        throw new MqttProtocolViolationException("Received malformed packet.");
                    }

                    _logger.Verbose("RX <<< {0}", packet);
                }
                finally
                {
                    // The raw packet is no longer needed once it is deserialized (or reading failed).
                    receivedMqttPacket?.Dispose();
                }
            }).ConfigureAwait(false);

            return packet;
        }

        // Reads one raw packet (fixed header + body) from the channel.
        // The returned packet has a null body when the body length is 0.
        private async Task<ReceivedMqttPacket> ReceiveAsync(IMqttChannel channel, CancellationToken cancellationToken)
        {
            var fixedHeader = await MqttPacketReader.ReadFixedHeaderAsync(channel, cancellationToken).ConfigureAwait(false);

            ReadingPacketStarted?.Invoke(this, EventArgs.Empty);
            try
            {
                var bodyLength = await MqttPacketReader.ReadBodyLengthAsync(channel, cancellationToken).ConfigureAwait(false);
                if (bodyLength == 0)
                {
                    return new ReceivedMqttPacket(fixedHeader, null);
                }

                var body = new MemoryStream(bodyLength);

                // The body is read in chunks of at most ReadBufferSize bytes.
                var buffer = new byte[Math.Min(ReadBufferSize, bodyLength)];

                while (body.Length < bodyLength)
                {
                    var bytesLeft = bodyLength - (int)body.Length;
                    if (bytesLeft > buffer.Length)
                    {
                        bytesLeft = buffer.Length;
                    }

                    var readBytesCount = await channel.ReadAsync(buffer, 0, bytesLeft, cancellationToken).ConfigureAwait(false);
                    if (readBytesCount <= 0)
                    {
                        // No more data although the body is incomplete.
                        // NOTE(review): presumably this helper always throws - confirm, otherwise the loop spins.
                        ExceptionHelper.ThrowGracefulSocketClose();
                    }

                    // Here is no need to await because internally only an array is used and no real I/O operation is made.
                    // Using async here will only generate overhead.
                    body.Write(buffer, 0, readBytesCount);
                }

                body.Seek(0L, SeekOrigin.Begin);

                return new ReceivedMqttPacket(fixedHeader, body);
            }
            finally
            {
                ReadingPacketCompleted?.Invoke(this, EventArgs.Empty);
            }
        }

        // Awaits the given action and translates exceptions:
        // - cancellation and MQTT communication exceptions are rethrown unchanged,
        // - a COMException with HResult ErrorOperationAborted and an IOException caused by
        //   SocketError.ConnectionAborted become an OperationCanceledException,
        // - everything else is wrapped into a MqttCommunicationException.
        private static async Task ExecuteAndWrapExceptionAsync(Func<Task> action)
        {
            try
            {
                await action().ConfigureAwait(false);
            }
            catch (OperationCanceledException)
            {
                // Also covers TaskCanceledException which derives from OperationCanceledException.
                throw;
            }
            catch (MqttCommunicationTimedOutException)
            {
                throw;
            }
            catch (MqttCommunicationException)
            {
                throw;
            }
            catch (COMException comException)
            {
                if ((uint)comException.HResult == ErrorOperationAborted)
                {
                    throw new OperationCanceledException();
                }

                throw new MqttCommunicationException(comException);
            }
            catch (IOException exception)
            {
                if (exception.InnerException is SocketException socketException)
                {
                    if (socketException.SocketErrorCode == SocketError.ConnectionAborted)
                    {
                        throw new OperationCanceledException();
                    }
                }

                throw new MqttCommunicationException(exception);
            }
            catch (Exception exception)
            {
                throw new MqttCommunicationException(exception);
            }
        }

        /// <summary>
        /// Marks the adapter as disposed and disposes the underlying channel.
        /// </summary>
        public void Dispose()
        {
            _isDisposed = true;
            _channel?.Dispose();
        }

        // Guards public operations against usage after Dispose.
        private void ThrowIfDisposed()
        {
            if (_isDisposed)
            {
                throw new ObjectDisposedException(nameof(MqttChannelAdapter));
            }
        }
    }
}