diff --git a/Source/MQTTnet.Extensions.Wrappers.WebSocket4Net/WebSocket4NetMqttClientAdapterFactory.cs b/Source/MQTTnet.Extensions.Wrappers.WebSocket4Net/WebSocket4NetMqttClientAdapterFactory.cs new file mode 100644 index 0000000..ed29be8 --- /dev/null +++ b/Source/MQTTnet.Extensions.Wrappers.WebSocket4Net/WebSocket4NetMqttClientAdapterFactory.cs @@ -0,0 +1,134 @@ +using MQTTnet.Client; +using System; +using System.Collections.Concurrent; +using System.Linq; +using System.Security.Authentication; +using System.Threading; +using System.Threading.Tasks; +using MQTTnet.Adapter; +using MQTTnet.Channel; +using MQTTnet.Diagnostics; +using MQTTnet.Serializer; +using WebSocket4Net; + +namespace MQTTnet.TestApp.NetCore +{ + public class WebSocket4NetMqttClientAdapterFactory : IMqttClientAdapterFactory + { + public IMqttChannelAdapter CreateClientAdapter(IMqttClientOptions options, IMqttNetChildLogger logger) + { + if (options == null) throw new ArgumentNullException(nameof(options)); + if (logger == null) throw new ArgumentNullException(nameof(logger)); + + if (!(options.ChannelOptions is MqttClientWebSocketOptions)) + { + throw new NotSupportedException("Only WebSocket connections are supported."); + } + + return new MqttChannelAdapter(new WebSocket4NetMqttChannel(options), new MqttPacketSerializer(), logger); + } + + private class WebSocket4NetMqttChannel : IMqttChannel + { + private readonly BlockingCollection _receiveBuffer = new BlockingCollection(); + + private readonly IMqttClientOptions _clientOptions; + private WebSocket4Net.WebSocket _webSocket; + + public WebSocket4NetMqttChannel(IMqttClientOptions clientOptions) + { + _clientOptions = clientOptions ?? throw new ArgumentNullException(nameof(clientOptions)); + } + + public string Endpoint { get; } = ""; + + public Task ConnectAsync(CancellationToken cancellationToken) + { + var channelOptions = (MqttClientWebSocketOptions)_clientOptions.ChannelOptions; + + var uri = "ws://" + channelOptions.Uri; + var sslProtocols = SslProtocols.None; + + if (channelOptions.TlsOptions.UseTls) + { + uri = "wss://" + channelOptions.Uri; + sslProtocols = SslProtocols.Tls12; + } + + var subProtocol = channelOptions.SubProtocols.FirstOrDefault() ?? string.Empty; + + _webSocket = new WebSocket4Net.WebSocket(uri, subProtocol, sslProtocols: sslProtocols); + _webSocket.DataReceived += OnDataReceived; + _webSocket.Open(); + SpinWait.SpinUntil(() => _webSocket.State == WebSocketState.Open, _clientOptions.CommunicationTimeout); + + return Task.FromResult(0); + } + + public Task DisconnectAsync() + { + if (_webSocket != null) + { + _webSocket.DataReceived -= OnDataReceived; + _webSocket.Close(); + SpinWait.SpinUntil(() => _webSocket.State == WebSocketState.Closed, _clientOptions.CommunicationTimeout); + } + + _webSocket = null; + + return Task.FromResult(0); + } + + public Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { + var readBytes = 0; + while (count > 0 && !cancellationToken.IsCancellationRequested) + { + byte @byte; + if (readBytes == 0) + { + // Block until at lease one byte was received. + @byte = _receiveBuffer.Take(cancellationToken); + } + else + { + if (!_receiveBuffer.TryTake(out @byte)) + { + return Task.FromResult(readBytes); + } + } + + buffer[offset] = @byte; + offset++; + count--; + readBytes++; + } + + return Task.FromResult(readBytes); + } + + public Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { + _webSocket.Send(buffer, offset, count); + return Task.FromResult(0); + } + + public void Dispose() + { + if (_webSocket != null) + { + _webSocket.DataReceived -= OnDataReceived; + _webSocket.Dispose(); + } + } + + private void OnDataReceived(object sender, WebSocket4Net.DataReceivedEventArgs e) + { + foreach (var @byte in e.Data) + { + _receiveBuffer.Add(@byte); + } + } + } + } +} diff --git a/Tests/MQTTnet.TestApp.NetCore/PublicBrokerTest.cs b/Tests/MQTTnet.TestApp.NetCore/PublicBrokerTest.cs index 1265a6a..e46ed16 100644 --- a/Tests/MQTTnet.TestApp.NetCore/PublicBrokerTest.cs +++ b/Tests/MQTTnet.TestApp.NetCore/PublicBrokerTest.cs @@ -3,6 +3,7 @@ using System; using System.IO; using System.Threading; using System.Threading.Tasks; +using MQTTnet.Diagnostics; using MQTTnet.Protocol; using Newtonsoft.Json; @@ -12,7 +13,7 @@ namespace MQTTnet.TestApp.NetCore { public static async Task RunAsync() { - //MqttNetGlobalLogger.LogMessagePublished += (s, e) => Console.WriteLine(e.TraceMessage); + //MqttNetConsoleLogger.ForwardToConsole(); // iot.eclipse.org await ExecuteTestAsync("iot.eclipse.org TCP", @@ -35,7 +36,7 @@ namespace MQTTnet.TestApp.NetCore new MqttClientOptionsBuilder().WithWebSocketServer("test.mosquitto.org:8080/mqtt").Build()); await ExecuteTestAsync("test.mosquitto.org WS TLS", - new MqttClientOptionsBuilder().WithWebSocketServer("test.mosquitto.org:8081/mqtt").Build()); + new MqttClientOptionsBuilder().WithWebSocketServer("test.mosquitto.org:8081/mqtt").WithTls().Build()); // broker.hivemq.com await ExecuteTestAsync("broker.hivemq.com TCP", @@ -78,7 +79,7 @@ namespace MQTTnet.TestApp.NetCore var topic = Guid.NewGuid().ToString(); MqttApplicationMessage receivedMessage = null; - client.ApplicationMessageReceived += (s, e) => receivedMessage = e.ApplicationMessage; + client.ApplicationMessageReceived += (s, e) => receivedMessage = e.ApplicationMessage; await client.ConnectAsync(options); await client.SubscribeAsync(topic, MqttQualityOfServiceLevel.AtLeastOnce);