using System;
using System.Net.Security;
using System.Net.Sockets;
using System.Security.Authentication;
using System.Security.Cryptography.X509Certificates;
using System.Threading.Tasks;
using MQTTnet.Core.Channel;
using MQTTnet.Core.Client;
using MQTTnet.Core.Exceptions;
using System.IO;
namespace MQTTnet.Implementations
{
/// <summary>
/// MQTT communication channel backed by a TCP socket, optionally wrapped in a TLS stream.
/// A single stream instance is exposed for receiving, sending and raw access.
/// </summary>
public sealed class MqttTcpChannel : IMqttCommunicationChannel, IDisposable
{
    private Stream _dataStream;
    private Socket _socket;
    private SslStream _sslStream;

    /// <summary>Stream used for reading incoming data (same instance as <see cref="SendStream"/>).</summary>
    public Stream ReceiveStream => _dataStream;

    /// <summary>The underlying data stream (same instance as <see cref="ReceiveStream"/>).</summary>
    public Stream RawStream => _dataStream;

    /// <summary>Stream used for writing outgoing data (same instance as <see cref="ReceiveStream"/>).</summary>
    public Stream SendStream => _dataStream;

    /// <summary>
    /// Creates an unconnected channel. Used on the client side; the socket is created
    /// in <see cref="ConnectAsync"/>.
    /// </summary>
    public MqttTcpChannel()
    {
    }

    /// <summary>
    /// Creates a channel around an existing socket. Used on the server side, where the
    /// socket is passed in and <see cref="ConnectAsync"/> is not expected to be called.
    /// </summary>
    /// <param name="socket">The socket to wrap. Must not be null.</param>
    /// <param name="sslStream">
    /// Optional TLS stream; when null, a plain <see cref="NetworkStream"/> over
    /// <paramref name="socket"/> is used instead.
    /// </param>
    /// <exception cref="ArgumentNullException"><paramref name="socket"/> is null.</exception>
    public MqttTcpChannel(Socket socket, SslStream sslStream)
    {
        _socket = socket ?? throw new ArgumentNullException(nameof(socket));
        _sslStream = sslStream;
        _dataStream = (Stream)sslStream ?? new NetworkStream(socket);
    }

    /// <summary>
    /// Connects to the server named in <paramref name="options"/> and, when TLS is enabled,
    /// performs the client-side TLS 1.2 handshake.
    /// </summary>
    /// <param name="options">Client options providing server, port and TLS settings.</param>
    /// <exception cref="ArgumentNullException"><paramref name="options"/> is null.</exception>
    /// <exception cref="MqttCommunicationException">A <see cref="SocketException"/> occurred.</exception>
    public async Task ConnectAsync(MqttClientOptions options)
    {
        if (options == null) throw new ArgumentNullException(nameof(options));

        try
        {
            if (_socket == null)
            {
                _socket = new Socket(SocketType.Stream, ProtocolType.Tcp);
            }

            await _socket.ConnectAsync(options.Server, options.GetPort()).ConfigureAwait(false);

            if (options.TlsOptions.UseTls)
            {
                // The NetworkStream owns the socket here, so disposing the SslStream closes it too.
                _sslStream = new SslStream(new NetworkStream(_socket, true));
                await _sslStream.AuthenticateAsClientAsync(options.Server, LoadCertificates(options), SslProtocols.Tls12, options.TlsOptions.CheckCertificateRevocation).ConfigureAwait(false);

                // Publish the stream only after the handshake has succeeded.
                _dataStream = _sslStream;
            }
            else
            {
                _dataStream = new NetworkStream(_socket);
            }
        }
        catch (SocketException exception)
        {
            // Release the failed socket so a retry starts from a fresh one.
            Dispose();
            throw new MqttCommunicationException(exception);
        }
        catch
        {
            // Any other failure (e.g. during the TLS handshake) must not leave a
            // half-initialised connection behind; the original exception is rethrown as-is.
            Dispose();
            throw;
        }
    }

    /// <summary>
    /// Closes the connection by disposing the channel's resources.
    /// </summary>
    /// <returns>An already completed task.</returns>
    /// <exception cref="MqttCommunicationException">A <see cref="SocketException"/> occurred.</exception>
    public Task DisconnectAsync()
    {
        try
        {
            Dispose();
            return Task.FromResult(0);
        }
        catch (SocketException exception)
        {
            throw new MqttCommunicationException(exception);
        }
    }

    /// <summary>
    /// Releases the streams and the socket. Safe to call more than once.
    /// </summary>
    public void Dispose()
    {
        // Close the streams before the socket they are layered on.
        _sslStream?.Dispose();

        // In the non-TLS case this is a NetworkStream that does not own the socket and
        // would otherwise never be disposed. In the TLS case it is the same instance as
        // _sslStream, and disposing a stream twice is harmless.
        _dataStream?.Dispose();

        _socket?.Dispose();

        _sslStream = null;
        _socket = null;

        // _dataStream is intentionally not set to null: a caller still holding the channel
        // then gets an ObjectDisposedException from the stream rather than a
        // NullReferenceException. ConnectAsync replaces it on reconnect.
    }

    /// <summary>
    /// Builds the client certificate collection from the raw certificates in the TLS options.
    /// </summary>
    /// <param name="options">Client options whose TLS certificates are loaded.</param>
    /// <returns>The loaded certificates; empty when none are configured.</returns>
    private static X509CertificateCollection LoadCertificates(MqttClientOptions options)
    {
        var certificates = new X509CertificateCollection();
        if (options.TlsOptions.Certificates == null)
        {
            return certificates;
        }

        foreach (var certificate in options.TlsOptions.Certificates)
        {
            certificates.Add(new X509Certificate(certificate));
        }

        return certificates;
    }
}
}