using System;
using System.Collections.Generic;
using System.IO;
using System.Threading;
using System.Threading.Tasks;
using MQTTnet.Core.Channel;
using MQTTnet.Core.Exceptions;
using MQTTnet.Core.Internal;
using MQTTnet.Core.Packets;
using MQTTnet.Core.Serializer;
using Microsoft.Extensions.Logging;

namespace MQTTnet.Core.Adapter
{
    /// <summary>
    /// Bridges an <see cref="IMqttCommunicationChannel"/> to <see cref="IMqttCommunicationAdapter"/>:
    /// outgoing packets are serialized with <see cref="PacketSerializer"/> and written to the channel's
    /// send stream; incoming packets are read from the channel's receive stream and deserialized.
    /// </summary>
    /// <remarks>
    /// All public operations share one exception-translation rule: cancellation exceptions
    /// (<see cref="OperationCanceledException"/>, which includes <see cref="TaskCanceledException"/>),
    /// <see cref="MqttCommunicationTimedOutException"/> and <see cref="MqttCommunicationException"/>
    /// propagate unchanged; any other exception is wrapped in a new <see cref="MqttCommunicationException"/>.
    /// </remarks>
    public class MqttChannelCommunicationAdapter : IMqttCommunicationAdapter
    {
        // Admits one SendPacketsAsync call at a time so the writes of concurrent senders
        // do not interleave on the channel's send stream.
        private readonly SemaphoreSlim _semaphore = new SemaphoreSlim(1, 1);
        private readonly ILogger _logger;
        private readonly IMqttCommunicationChannel _channel;

        /// <summary>Creates an adapter over the given channel.</summary>
        /// <param name="channel">Transport providing the send/receive streams and connect/disconnect.</param>
        /// <param name="serializer">Converts packets to and from their wire representation.</param>
        /// <param name="logger">Receives a log entry for every packet sent or received.</param>
        /// <exception cref="ArgumentNullException">Any argument is <c>null</c>.</exception>
        public MqttChannelCommunicationAdapter(IMqttCommunicationChannel channel, IMqttPacketSerializer serializer, ILogger logger)
        {
            _logger = logger ?? throw new ArgumentNullException(nameof(logger));
            _channel = channel ?? throw new ArgumentNullException(nameof(channel));
            PacketSerializer = serializer ?? throw new ArgumentNullException(nameof(serializer));
        }

        /// <summary>Serializer used for both outgoing and incoming packets.</summary>
        public IMqttPacketSerializer PacketSerializer { get; }

        /// <summary>Connects the underlying channel.</summary>
        /// <param name="timeout">
        /// Passed to <c>TimeoutAfter</c> unconditionally (unlike send/receive, a non-positive value is not
        /// treated as "no timeout" here). NOTE(review): <c>TimeoutAfter</c> is defined elsewhere; presumably it
        /// raises <see cref="MqttCommunicationTimedOutException"/> on expiry — confirm.
        /// </param>
        public async Task ConnectAsync(TimeSpan timeout)
        {
            try
            {
                await _channel.ConnectAsync().TimeoutAfter(timeout).ConfigureAwait(false);
            }
            catch (Exception exception) when (!IsPassThroughException(exception))
            {
                throw new MqttCommunicationException(exception);
            }
        }

        /// <summary>Disconnects the underlying channel.</summary>
        /// <param name="timeout">Passed to <c>TimeoutAfter</c> unconditionally, as in <see cref="ConnectAsync"/>.</param>
        public async Task DisconnectAsync(TimeSpan timeout)
        {
            try
            {
                await _channel.DisconnectAsync().TimeoutAfter(timeout).ConfigureAwait(false);
            }
            catch (Exception exception) when (!IsPassThroughException(exception))
            {
                throw new MqttCommunicationException(exception);
            }
        }

        /// <summary>
        /// Serializes and writes every non-null packet to the send stream, then flushes it.
        /// Concurrent calls are serialized by a semaphore.
        /// </summary>
        /// <param name="timeout">
        /// Applied to the final flush only, and only when greater than <see cref="TimeSpan.Zero"/>;
        /// the individual writes are bounded solely by <paramref name="cancellationToken"/>.
        /// </param>
        /// <param name="cancellationToken">Cancels the semaphore wait, the writes and the flush.</param>
        /// <param name="packets">Packets to send in order; <c>null</c> entries are skipped.</param>
        public async Task SendPacketsAsync(TimeSpan timeout, CancellationToken cancellationToken, IEnumerable packets)
        {
            try
            {
                // Acquire before entering the try/finally that releases: if this wait is cancelled,
                // the semaphore was never taken and releasing it would either throw
                // SemaphoreFullException or admit a second concurrent sender.
                await _semaphore.WaitAsync(cancellationToken).ConfigureAwait(false);
                try
                {
                    foreach (var packet in packets)
                    {
                        if (packet == null)
                        {
                            continue;
                        }

                        _logger.LogInformation("TX >>> {0} [Timeout={1}]", packet, timeout);

                        var writeBuffer = PacketSerializer.Serialize(packet);
                        await _channel.SendStream.WriteAsync(writeBuffer, 0, writeBuffer.Length, cancellationToken).ConfigureAwait(false);
                    }

                    if (timeout > TimeSpan.Zero)
                    {
                        await _channel.SendStream.FlushAsync(cancellationToken).TimeoutAfter(timeout).ConfigureAwait(false);
                    }
                    else
                    {
                        await _channel.SendStream.FlushAsync(cancellationToken).ConfigureAwait(false);
                    }
                }
                finally
                {
                    _semaphore.Release();
                }
            }
            catch (Exception exception) when (!IsPassThroughException(exception))
            {
                throw new MqttCommunicationException(exception);
            }
        }

        /// <summary>Reads one packet from the receive stream and deserializes it.</summary>
        /// <param name="timeout">
        /// Bounds the raw read when greater than <see cref="TimeSpan.Zero"/>; otherwise no timeout is applied.
        /// </param>
        /// <param name="cancellationToken">Cancels the read.</param>
        /// <returns>The deserialized packet.</returns>
        /// <exception cref="TaskCanceledException">
        /// No header could be read (the raw read yielded <c>null</c>) or cancellation was requested.
        /// </exception>
        /// <remarks>
        /// A <c>null</c> deserialization result raises <see cref="MqttProtocolViolationException"/>, which is then
        /// subject to the class-wide exception translation. The raw packet is always disposed.
        /// </remarks>
        public async Task ReceivePacketAsync(TimeSpan timeout, CancellationToken cancellationToken)
        {
            ReceivedMqttPacket receivedMqttPacket = null;
            try
            {
                if (timeout > TimeSpan.Zero)
                {
                    receivedMqttPacket = await ReceiveAsync(_channel.ReceiveStream, cancellationToken).TimeoutAfter(timeout).ConfigureAwait(false);
                }
                else
                {
                    receivedMqttPacket = await ReceiveAsync(_channel.ReceiveStream, cancellationToken).ConfigureAwait(false);
                }

                if (receivedMqttPacket == null || cancellationToken.IsCancellationRequested)
                {
                    throw new TaskCanceledException();
                }

                var packet = PacketSerializer.Deserialize(receivedMqttPacket);
                if (packet == null)
                {
                    throw new MqttProtocolViolationException("Received malformed packet.");
                }

                _logger.LogInformation("RX <<< {0}", packet);
                return packet;
            }
            catch (Exception exception) when (!IsPassThroughException(exception))
            {
                throw new MqttCommunicationException(exception);
            }
            finally
            {
                receivedMqttPacket?.Dispose();
            }
        }

        /// <summary>
        /// Reads a packet header and then exactly <c>header.BodyLength</c> body bytes from <paramref name="stream"/>.
        /// </summary>
        /// <returns>
        /// The header plus a <see cref="MemoryStream"/> over the body, or <c>null</c> when no header was read.
        /// </returns>
        /// <exception cref="EndOfStreamException">The stream ended before the full body was read.</exception>
        private static async Task ReceiveAsync(Stream stream, CancellationToken cancellationToken)
        {
            // NOTE(review): the header is read with a synchronous call inside this async method, so it
            // blocks the calling thread until header bytes arrive; prefer an async variant if one exists.
            var header = MqttPacketReader.ReadHeaderFromSource(stream, cancellationToken);
            if (header == null)
            {
                return null;
            }

            if (header.BodyLength == 0)
            {
                return new ReceivedMqttPacket(header, new MemoryStream(0));
            }

            var body = new byte[header.BodyLength];
            var offset = 0;
            while (offset < body.Length)
            {
                var readBytesCount = await stream.ReadAsync(body, offset, body.Length - offset, cancellationToken).ConfigureAwait(false);
                if (readBytesCount == 0)
                {
                    // ReadAsync returns 0 only at end of stream; without this check a stream closed
                    // mid-packet would make this loop spin forever.
                    throw new EndOfStreamException("Stream ended before the full packet body was received.");
                }

                offset += readBytesCount;
            }

            return new ReceivedMqttPacket(header, new MemoryStream(body, 0, body.Length));
        }

        /// <summary>
        /// Returns <c>true</c> for exceptions the public operations rethrow unchanged instead of wrapping
        /// in <see cref="MqttCommunicationException"/>.
        /// </summary>
        private static bool IsPassThroughException(Exception exception)
        {
            // TaskCanceledException derives from OperationCanceledException, so one check covers both.
            return exception is OperationCanceledException
                || exception is MqttCommunicationTimedOutException
                || exception is MqttCommunicationException;
        }
    }
}