using System;
using System.Collections.Generic;
using System.IO;
using System.Threading;
using System.Threading.Tasks;
using MQTTnet.Core.Channel;
using MQTTnet.Core.Client;
using MQTTnet.Core.Diagnostics;
using MQTTnet.Core.Exceptions;
using MQTTnet.Core.Internal;
using MQTTnet.Core.Packets;
using MQTTnet.Core.Serializer;

namespace MQTTnet.Core.Adapter
{
    /// <summary>
    /// Adapts an <see cref="IMqttCommunicationChannel"/> to <see cref="IMqttCommunicationAdapter"/>:
    /// outgoing packets are serialized with <see cref="PacketSerializer"/> and written to the
    /// channel's send stream; incoming packets are read from the channel's receive streams and
    /// deserialized.
    /// </summary>
    public class MqttChannelCommunicationAdapter : IMqttCommunicationAdapter
    {
        private readonly IMqttCommunicationChannel _channel;

        // Reusable buffer for packet bodies that fit into it; larger bodies get their own array.
        private readonly byte[] _readBuffer = new byte[BufferConstants.Size];

        // Async mutual exclusion for SendPacketsAsync so that the bytes of packets sent by
        // concurrent callers are never interleaved on the send stream.
        private readonly SemaphoreSlim _sendLock = new SemaphoreSlim(1, 1);

        /// <summary>
        /// Creates an adapter over <paramref name="channel"/> using <paramref name="serializer"/>.
        /// </summary>
        /// <param name="channel">The underlying communication channel.</param>
        /// <param name="serializer">Serializer used to convert packets to and from bytes.</param>
        /// <exception cref="ArgumentNullException">If either argument is null.</exception>
        public MqttChannelCommunicationAdapter(IMqttCommunicationChannel channel, IMqttPacketSerializer serializer)
        {
            _channel = channel ?? throw new ArgumentNullException(nameof(channel));
            PacketSerializer = serializer ?? throw new ArgumentNullException(nameof(serializer));
        }

        /// <summary>The serializer used for outgoing and incoming packets.</summary>
        public IMqttPacketSerializer PacketSerializer { get; }

        /// <summary>
        /// Connects the underlying channel, failing if it does not complete within <paramref name="timeout"/>.
        /// </summary>
        public Task ConnectAsync(MqttClientOptions options, TimeSpan timeout)
        {
            return _channel.ConnectAsync(options).TimeoutAfter(timeout);
        }

        /// <summary>Disconnects the underlying channel.</summary>
        public Task DisconnectAsync()
        {
            return _channel.DisconnectAsync();
        }

        /// <summary>
        /// Serializes and writes <paramref name="packets"/> to the channel in order, then flushes.
        /// Concurrent calls are serialized so packets from different callers never interleave.
        /// </summary>
        /// <param name="timeout">Timeout applied to each write and to the final flush.</param>
        /// <param name="packets">The packets to send.</param>
        /// <exception cref="ArgumentNullException">If <paramref name="packets"/> is null.</exception>
        public async Task SendPacketsAsync(TimeSpan timeout, IEnumerable<MqttBasePacket> packets)
        {
            if (packets == null)
            {
                throw new ArgumentNullException(nameof(packets));
            }

            await _sendLock.WaitAsync().ConfigureAwait(false);
            try
            {
                foreach (var packet in packets)
                {
                    MqttTrace.Information(nameof(MqttChannelCommunicationAdapter), "TX >>> {0} [Timeout={1}]", packet, timeout);

                    var writeBuffer = PacketSerializer.Serialize(packet);

                    // Awaiting each write (instead of chaining ContinueWith, which only waited for the
                    // write to *start*) guarantees ordering and surfaces write failures to the caller.
                    await _channel.SendStream.WriteAsync(writeBuffer, 0, writeBuffer.Length).TimeoutAfter(timeout).ConfigureAwait(false);
                }

                await _channel.SendStream.FlushAsync().TimeoutAfter(timeout).ConfigureAwait(false);
            }
            finally
            {
                _sendLock.Release();
            }
        }

        /// <summary>
        /// Receives and deserializes the next packet.
        /// </summary>
        /// <param name="timeout">
        /// When greater than zero, the packet is read from the channel's raw stream with this timeout;
        /// otherwise it is read from the receive stream without a timeout.
        /// </param>
        /// <returns>The deserialized packet.</returns>
        /// <exception cref="MqttProtocolViolationException">If the serializer returns null.</exception>
        /// <exception cref="EndOfStreamException">If the stream ends in the middle of a packet body.</exception>
        public async Task<MqttBasePacket> ReceivePacketAsync(TimeSpan timeout)
        {
            Tuple<MqttPacketHeader, MemoryStream> tuple;
            if (timeout > TimeSpan.Zero)
            {
                tuple = await ReceiveAsync(_channel.RawStream).TimeoutAfter(timeout).ConfigureAwait(false);
            }
            else
            {
                tuple = await ReceiveAsync(_channel.ReceiveStream).ConfigureAwait(false);
            }

            var packet = PacketSerializer.Deserialize(tuple.Item1, tuple.Item2);
            if (packet == null)
            {
                throw new MqttProtocolViolationException("Received malformed packet.");
            }

            MqttTrace.Information(nameof(MqttChannelCommunicationAdapter), "RX <<< {0}", packet);
            return packet;
        }

        /// <summary>
        /// Reads one packet header and its complete body from <paramref name="stream"/>.
        /// </summary>
        /// <returns>The header and a stream over exactly <c>header.BodyLength</c> body bytes.</returns>
        private async Task<Tuple<MqttPacketHeader, MemoryStream>> ReceiveAsync(Stream stream)
        {
            // NOTE(review): the header is read synchronously before the first await, so a
            // TimeoutAfter applied by the caller does not cover the header read, only the body.
            var header = MqttPacketReader.ReadHeaderFromSource(stream);

            var bodyLength = header.BodyLength;
            if (bodyLength <= 0)
            {
                return Tuple.Create(header, new MemoryStream());
            }

            // The shared buffer has a fixed size; bodies larger than it need their own array,
            // otherwise ReadAsync would be asked to write past the end of the buffer.
            var buffer = bodyLength <= _readBuffer.Length ? _readBuffer : new byte[bodyLength];

            var totalRead = 0;
            while (totalRead < bodyLength)
            {
                var read = await stream.ReadAsync(buffer, totalRead, bodyLength - totalRead).ConfigureAwait(false);
                if (read == 0)
                {
                    // ReadAsync returns 0 at end of stream; without this check the loop would spin forever.
                    throw new EndOfStreamException("The stream ended before the complete MQTT packet body was received.");
                }

                totalRead += read;
            }

            return Tuple.Create(header, new MemoryStream(buffer, 0, bodyLength));
        }
    }
}