@@ -72,11 +72,7 @@ namespace MQTTnet.Implementations | |||
// Workaround for: workaround for https://github.com/dotnet/corefx/issues/24430 | |||
using (cancellationToken.Register(() => socket.Dispose())) | |||
{ | |||
#if NET452 || NET461 | |||
await Task.Factory.FromAsync(socket.BeginConnect, socket.EndConnect, _options.Server, _options.GetPort(), null).ConfigureAwait(false); | |||
#else | |||
await socket.ConnectAsync(_options.Server, _options.GetPort()).ConfigureAwait(false); | |||
#endif | |||
await PlatformAbstractionLayer.ConnectAsync(socket, _options.Server, _options.GetPort()).ConfigureAwait(false); | |||
} | |||
var networkStream = new NetworkStream(socket, true); | |||
@@ -117,6 +113,10 @@ namespace MQTTnet.Implementations | |||
return await _stream.ReadAsync(buffer, offset, count, cancellationToken).ConfigureAwait(false); | |||
} | |||
} | |||
catch (ObjectDisposedException) | |||
{ | |||
return 0; | |||
} | |||
catch (IOException exception) | |||
{ | |||
if (exception.InnerException is SocketException socketException) | |||
@@ -143,6 +143,10 @@ namespace MQTTnet.Implementations | |||
await _stream.WriteAsync(buffer, offset, count, cancellationToken).ConfigureAwait(false); | |||
} | |||
} | |||
catch (ObjectDisposedException) | |||
{ | |||
return; | |||
} | |||
catch (IOException exception) | |||
{ | |||
if (exception.InnerException is SocketException socketException) | |||
@@ -107,12 +107,7 @@ namespace MQTTnet.Implementations | |||
{ | |||
try | |||
{ | |||
#if NET452 || NET461 | |||
var clientSocket = await Task.Factory.FromAsync(_socket.BeginAccept, _socket.EndAccept, null).ConfigureAwait(false); | |||
#else | |||
var clientSocket = await _socket.AcceptAsync().ConfigureAwait(false); | |||
#endif | |||
var clientSocket = await PlatformAbstractionLayer.AcceptAsync(_socket).ConfigureAwait(false); | |||
if (clientSocket == null) | |||
{ | |||
continue; | |||
@@ -0,0 +1,92 @@ | |||
using System; | |||
using System.Net; | |||
using System.Net.Sockets; | |||
using System.Threading.Tasks; | |||
namespace MQTTnet.Implementations | |||
{ | |||
public static class PlatformAbstractionLayer | |||
{ | |||
public static async Task<Socket> AcceptAsync(Socket socket) | |||
{ | |||
#if NET452 || NET461 | |||
try | |||
{ | |||
return await Task.Factory.FromAsync(socket.BeginAccept, socket.EndAccept, null).ConfigureAwait(false); | |||
} | |||
catch (ObjectDisposedException) | |||
{ | |||
return null; | |||
} | |||
#else | |||
return await socket.AcceptAsync().ConfigureAwait(false); | |||
#endif | |||
} | |||
public static Task ConnectAsync(Socket socket, IPAddress ip, int port) | |||
{ | |||
#if NET452 || NET461 | |||
return Task.Factory.FromAsync(socket.BeginConnect, socket.EndConnect, ip, port, null); | |||
#else | |||
return socket.ConnectAsync(ip, port); | |||
#endif | |||
} | |||
public static Task ConnectAsync(Socket socket, string host, int port) | |||
{ | |||
#if NET452 || NET461 | |||
return Task.Factory.FromAsync(socket.BeginConnect, socket.EndConnect, host, port, null); | |||
#else | |||
return socket.ConnectAsync(host, port); | |||
#endif | |||
} | |||
#if NET452 || NET461 | |||
public class SocketWrapper | |||
{ | |||
private readonly Socket _socket; | |||
private readonly ArraySegment<byte> _buffer; | |||
private readonly SocketFlags _socketFlags; | |||
public SocketWrapper(Socket socket, ArraySegment<byte> buffer, SocketFlags socketFlags) | |||
{ | |||
_socket = socket; | |||
_buffer = buffer; | |||
_socketFlags = socketFlags; | |||
} | |||
public static IAsyncResult BeginSend(AsyncCallback callback, object state) | |||
{ | |||
var real = (SocketWrapper)state; | |||
return real._socket.BeginSend(real._buffer.Array, real._buffer.Offset, real._buffer.Count, real._socketFlags, callback, state); | |||
} | |||
public static IAsyncResult BeginReceive(AsyncCallback callback, object state) | |||
{ | |||
var real = (SocketWrapper)state; | |||
return real._socket.BeginReceive(real._buffer.Array, real._buffer.Offset, real._buffer.Count, real._socketFlags, callback, state); | |||
} | |||
} | |||
#endif | |||
public static Task SendAsync(Socket socket, ArraySegment<byte> buffer, SocketFlags socketFlags) | |||
{ | |||
#if NET452 || NET461 | |||
return Task.Factory.FromAsync(SocketWrapper.BeginSend, socket.EndSend, new SocketWrapper(socket, buffer, socketFlags)); | |||
#else | |||
return socket.SendAsync(buffer, socketFlags); | |||
#endif | |||
} | |||
public static Task<int> ReceiveAsync(Socket socket, ArraySegment<byte> buffer, SocketFlags socketFlags) | |||
{ | |||
#if NET452 || NET461 | |||
return Task.Factory.FromAsync(SocketWrapper.BeginReceive, socket.EndReceive, new SocketWrapper(socket, buffer, socketFlags)); | |||
#else | |||
return socket.ReceiveAsync(buffer, socketFlags); | |||
#endif | |||
} | |||
} | |||
} |
@@ -1,7 +1,7 @@ | |||
<Project Sdk="Microsoft.NET.Sdk"> | |||
<PropertyGroup> | |||
<TargetFramework>netcoreapp2.2</TargetFramework> | |||
<TargetFrameworks>netcoreapp2.2;net461</TargetFrameworks> | |||
<IsPackable>false</IsPackable> | |||
</PropertyGroup> | |||
@@ -39,7 +39,7 @@ namespace MQTTnet.Tests.Mockups | |||
case MqttClientOptionsBuilder builder: | |||
{ | |||
var existingClientId = builder.Build().ClientId; | |||
if (!existingClientId.StartsWith(TestContext.TestName)) | |||
if (existingClientId != null && !existingClientId.StartsWith(TestContext.TestName)) | |||
{ | |||
builder.WithClientId(TestContext.TestName + existingClientId); | |||
} | |||
@@ -48,7 +48,7 @@ namespace MQTTnet.Tests.Mockups | |||
case MqttClientOptions op: | |||
{ | |||
var existingClientId = op.ClientId; | |||
if (!existingClientId.StartsWith(TestContext.TestName)) | |||
if (existingClientId != null && !existingClientId.StartsWith(TestContext.TestName)) | |||
{ | |||
op.ClientId = TestContext.TestName + existingClientId; | |||
} | |||
@@ -237,6 +237,36 @@ namespace MQTTnet.Tests | |||
} | |||
} | |||
[TestMethod] | |||
public async Task ConnectTimeout_Throws_Exception() | |||
{ | |||
var factory = new MqttFactory(); | |||
using (var client = factory.CreateMqttClient()) | |||
{ | |||
bool disconnectHandlerCalled = false; | |||
try | |||
{ | |||
client.DisconnectedHandler = new MqttClientDisconnectedHandlerDelegate(args => | |||
{ | |||
disconnectHandlerCalled = true; | |||
}); | |||
await client.ConnectAsync(new MqttClientOptionsBuilder().WithTcpServer("1.2.3.4").Build()); | |||
Assert.Fail("Must fail!"); | |||
} | |||
catch (Exception exception) | |||
{ | |||
Assert.IsNotNull(exception); | |||
Assert.IsInstanceOfType(exception, typeof(MqttCommunicationException)); | |||
//Assert.IsInstanceOfType(exception.InnerException, typeof(SocketException)); | |||
} | |||
await Task.Delay(100); // disconnected handler is called async | |||
Assert.IsTrue(disconnectHandlerCalled); | |||
} | |||
} | |||
[TestMethod] | |||
public async Task Fire_Disconnected_Event_On_Server_Shutdown() | |||
{ | |||
@@ -28,14 +28,14 @@ namespace MQTTnet.Tests | |||
{ | |||
while (!ct.IsCancellationRequested) | |||
{ | |||
var client = await serverSocket.AcceptAsync(); | |||
var client = await PlatformAbstractionLayer.AcceptAsync(serverSocket); | |||
var data = new byte[] { 128 }; | |||
await client.SendAsync(new ArraySegment<byte>(data), SocketFlags.None); | |||
await PlatformAbstractionLayer.SendAsync(client, new ArraySegment<byte>(data), SocketFlags.None); | |||
} | |||
}, ct.Token); | |||
var clientSocket = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); | |||
await clientSocket.ConnectAsync(IPAddress.Loopback, 50001); | |||
await PlatformAbstractionLayer.ConnectAsync(clientSocket, IPAddress.Loopback, 50001); | |||
await Task.Delay(100, ct.Token); | |||
@@ -13,6 +13,7 @@ using MQTTnet.Client.Disconnecting; | |||
using MQTTnet.Client.Options; | |||
using MQTTnet.Client.Receiving; | |||
using MQTTnet.Client.Subscribing; | |||
using MQTTnet.Implementations; | |||
using MQTTnet.Protocol; | |||
using MQTTnet.Server; | |||
using MQTTnet.Tests.Mockups; | |||
@@ -1140,7 +1141,7 @@ namespace MQTTnet.Tests | |||
await testEnvironment.ConnectClientAsync(); | |||
} | |||
} | |||
[TestMethod] | |||
public async Task Close_Idle_Connection() | |||
{ | |||
@@ -1149,14 +1150,14 @@ namespace MQTTnet.Tests | |||
await testEnvironment.StartServerAsync(new MqttServerOptionsBuilder().WithDefaultCommunicationTimeout(TimeSpan.FromSeconds(1))); | |||
var client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); | |||
await client.ConnectAsync("localhost", testEnvironment.ServerPort); | |||
await PlatformAbstractionLayer.ConnectAsync(client, "localhost", testEnvironment.ServerPort); | |||
// Don't send anything. The server should close the connection. | |||
await Task.Delay(TimeSpan.FromSeconds(3)); | |||
try | |||
{ | |||
var receivedBytes = await client.ReceiveAsync(new ArraySegment<byte>(new byte[10]), SocketFlags.Partial); | |||
var receivedBytes = await PlatformAbstractionLayer.ReceiveAsync(client, new ArraySegment<byte>(new byte[10]), SocketFlags.Partial); | |||
if (receivedBytes == 0) | |||
{ | |||
return; | |||
@@ -1180,7 +1181,7 @@ namespace MQTTnet.Tests | |||
// Send an invalid packet and ensure that the server will close the connection and stay in a waiting state | |||
// forever. This is security related. | |||
var client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); | |||
await client.ConnectAsync("localhost", testEnvironment.ServerPort); | |||
await PlatformAbstractionLayer.ConnectAsync(client, "localhost", testEnvironment.ServerPort); | |||
var buffer = Encoding.UTF8.GetBytes("Garbage"); | |||
client.Send(buffer, buffer.Length, SocketFlags.None); | |||
@@ -1189,7 +1190,7 @@ namespace MQTTnet.Tests | |||
try | |||
{ | |||
var receivedBytes = await client.ReceiveAsync(new ArraySegment<byte>(new byte[10]), SocketFlags.Partial); | |||
var receivedBytes = await PlatformAbstractionLayer.ReceiveAsync(client, new ArraySegment<byte>(new byte[10]), SocketFlags.Partial); | |||
if (receivedBytes == 0) | |||
{ | |||
return; | |||