diff --git a/Source/MQTTnet/Implementations/MqttTcpChannel.cs b/Source/MQTTnet/Implementations/MqttTcpChannel.cs index 8f012cb..d500813 100644 --- a/Source/MQTTnet/Implementations/MqttTcpChannel.cs +++ b/Source/MQTTnet/Implementations/MqttTcpChannel.cs @@ -72,11 +72,7 @@ namespace MQTTnet.Implementations // Workaround for: workaround for https://github.com/dotnet/corefx/issues/24430 using (cancellationToken.Register(() => socket.Dispose())) { -#if NET452 || NET461 - await Task.Factory.FromAsync(socket.BeginConnect, socket.EndConnect, _options.Server, _options.GetPort(), null).ConfigureAwait(false); -#else - await socket.ConnectAsync(_options.Server, _options.GetPort()).ConfigureAwait(false); -#endif + await PlatformAbstractionLayer.ConnectAsync(socket, _options.Server, _options.GetPort()).ConfigureAwait(false); } var networkStream = new NetworkStream(socket, true); @@ -117,6 +113,10 @@ namespace MQTTnet.Implementations return await _stream.ReadAsync(buffer, offset, count, cancellationToken).ConfigureAwait(false); } } + catch (ObjectDisposedException) + { + return 0; + } catch (IOException exception) { if (exception.InnerException is SocketException socketException) @@ -143,6 +143,10 @@ namespace MQTTnet.Implementations await _stream.WriteAsync(buffer, offset, count, cancellationToken).ConfigureAwait(false); } } + catch (ObjectDisposedException) + { + return; + } catch (IOException exception) { if (exception.InnerException is SocketException socketException) diff --git a/Source/MQTTnet/Implementations/MqttTcpServerListener.cs b/Source/MQTTnet/Implementations/MqttTcpServerListener.cs index d57888e..f2f439e 100644 --- a/Source/MQTTnet/Implementations/MqttTcpServerListener.cs +++ b/Source/MQTTnet/Implementations/MqttTcpServerListener.cs @@ -107,12 +107,7 @@ namespace MQTTnet.Implementations { try { -#if NET452 || NET461 - var clientSocket = await Task.Factory.FromAsync(_socket.BeginAccept, _socket.EndAccept, null).ConfigureAwait(false); -#else - var clientSocket = await _socket.AcceptAsync().ConfigureAwait(false); -#endif - + var clientSocket = await PlatformAbstractionLayer.AcceptAsync(_socket).ConfigureAwait(false); if (clientSocket == null) { continue; diff --git a/Source/MQTTnet/Implementations/PlatformAbstractionLayer.cs b/Source/MQTTnet/Implementations/PlatformAbstractionLayer.cs new file mode 100644 index 0000000..a940eac --- /dev/null +++ b/Source/MQTTnet/Implementations/PlatformAbstractionLayer.cs @@ -0,0 +1,92 @@ +using System; +using System.Net; +using System.Net.Sockets; +using System.Threading.Tasks; + +namespace MQTTnet.Implementations +{ + public static class PlatformAbstractionLayer + { + public static async Task AcceptAsync(Socket socket) + { +#if NET452 || NET461 + try + { + return await Task.Factory.FromAsync(socket.BeginAccept, socket.EndAccept, null).ConfigureAwait(false); + } + catch (ObjectDisposedException) + { + return null; + } +#else + return await socket.AcceptAsync().ConfigureAwait(false); +#endif + } + + + public static Task ConnectAsync(Socket socket, IPAddress ip, int port) + { +#if NET452 || NET461 + return Task.Factory.FromAsync(socket.BeginConnect, socket.EndConnect, ip, port, null); +#else + return socket.ConnectAsync(ip, port); +#endif + } + + public static Task ConnectAsync(Socket socket, string host, int port) + { +#if NET452 || NET461 + return Task.Factory.FromAsync(socket.BeginConnect, socket.EndConnect, host, port, null); +#else + return socket.ConnectAsync(host, port); +#endif + } + +#if NET452 || NET461 + public class SocketWrapper + { + private readonly Socket _socket; + private readonly ArraySegment _buffer; + private readonly SocketFlags _socketFlags; + + public SocketWrapper(Socket socket, ArraySegment buffer, SocketFlags socketFlags) + { + _socket = socket; + _buffer = buffer; + _socketFlags = socketFlags; + } + + public static IAsyncResult BeginSend(AsyncCallback callback, object state) + { + var real = (SocketWrapper)state; + return real._socket.BeginSend(real._buffer.Array, real._buffer.Offset, real._buffer.Count, real._socketFlags, callback, state); + } + + public static IAsyncResult BeginReceive(AsyncCallback callback, object state) + { + var real = (SocketWrapper)state; + return real._socket.BeginReceive(real._buffer.Array, real._buffer.Offset, real._buffer.Count, real._socketFlags, callback, state); + } + } +#endif + + public static Task SendAsync(Socket socket, ArraySegment buffer, SocketFlags socketFlags) + { +#if NET452 || NET461 + return Task.Factory.FromAsync(SocketWrapper.BeginSend, socket.EndSend, new SocketWrapper(socket, buffer, socketFlags)); +#else + return socket.SendAsync(buffer, socketFlags); +#endif + } + + public static Task ReceiveAsync(Socket socket, ArraySegment buffer, SocketFlags socketFlags) + { +#if NET452 || NET461 + return Task.Factory.FromAsync(SocketWrapper.BeginReceive, socket.EndReceive, new SocketWrapper(socket, buffer, socketFlags)); +#else + return socket.ReceiveAsync(buffer, socketFlags); +#endif + } + + } +} diff --git a/Tests/MQTTnet.Core.Tests/MQTTnet.Tests.csproj b/Tests/MQTTnet.Core.Tests/MQTTnet.Tests.csproj index a4b6881..7bf14cb 100644 --- a/Tests/MQTTnet.Core.Tests/MQTTnet.Tests.csproj +++ b/Tests/MQTTnet.Core.Tests/MQTTnet.Tests.csproj @@ -1,7 +1,7 @@  - netcoreapp2.2 + netcoreapp2.2;net461 false diff --git a/Tests/MQTTnet.Core.Tests/Mockups/TestClientWrapper.cs b/Tests/MQTTnet.Core.Tests/Mockups/TestClientWrapper.cs index dc0e3a9..2500a6f 100644 --- a/Tests/MQTTnet.Core.Tests/Mockups/TestClientWrapper.cs +++ b/Tests/MQTTnet.Core.Tests/Mockups/TestClientWrapper.cs @@ -39,7 +39,7 @@ namespace MQTTnet.Tests.Mockups case MqttClientOptionsBuilder builder: { var existingClientId = builder.Build().ClientId; - if (!existingClientId.StartsWith(TestContext.TestName)) + if (existingClientId != null && !existingClientId.StartsWith(TestContext.TestName)) { builder.WithClientId(TestContext.TestName + existingClientId); } @@ -48,7 +48,7 @@ namespace MQTTnet.Tests.Mockups case MqttClientOptions op: { var existingClientId = op.ClientId; - if (!existingClientId.StartsWith(TestContext.TestName)) + if (existingClientId != null && !existingClientId.StartsWith(TestContext.TestName)) { op.ClientId = TestContext.TestName + existingClientId; } diff --git a/Tests/MQTTnet.Core.Tests/MqttClient_Tests.cs b/Tests/MQTTnet.Core.Tests/MqttClient_Tests.cs index 9cd3024..b8f4fbc 100644 --- a/Tests/MQTTnet.Core.Tests/MqttClient_Tests.cs +++ b/Tests/MQTTnet.Core.Tests/MqttClient_Tests.cs @@ -237,6 +237,36 @@ namespace MQTTnet.Tests } } + [TestMethod] + public async Task ConnectTimeout_Throws_Exception() + { + var factory = new MqttFactory(); + using (var client = factory.CreateMqttClient()) + { + bool disconnectHandlerCalled = false; + try + { + client.DisconnectedHandler = new MqttClientDisconnectedHandlerDelegate(args => + { + disconnectHandlerCalled = true; + }); + + await client.ConnectAsync(new MqttClientOptionsBuilder().WithTcpServer("1.2.3.4").Build()); + + Assert.Fail("Must fail!"); + } + catch (Exception exception) + { + Assert.IsNotNull(exception); + Assert.IsInstanceOfType(exception, typeof(MqttCommunicationException)); + //Assert.IsInstanceOfType(exception.InnerException, typeof(SocketException)); + } + + await Task.Delay(100); // disconnected handler is called async + Assert.IsTrue(disconnectHandlerCalled); + } + } + [TestMethod] public async Task Fire_Disconnected_Event_On_Server_Shutdown() { diff --git a/Tests/MQTTnet.Core.Tests/MqttTcpChannel_Tests.cs b/Tests/MQTTnet.Core.Tests/MqttTcpChannel_Tests.cs index 436d2d1..a4b0ca7 100644 --- a/Tests/MQTTnet.Core.Tests/MqttTcpChannel_Tests.cs +++ b/Tests/MQTTnet.Core.Tests/MqttTcpChannel_Tests.cs @@ -28,14 +28,14 @@ namespace MQTTnet.Tests { while (!ct.IsCancellationRequested) { - var client = await serverSocket.AcceptAsync(); + var client = await PlatformAbstractionLayer.AcceptAsync(serverSocket); var data = new byte[] { 128 }; - await client.SendAsync(new ArraySegment(data), SocketFlags.None); + await PlatformAbstractionLayer.SendAsync(client, new ArraySegment(data), SocketFlags.None); } }, ct.Token); var clientSocket = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); - await clientSocket.ConnectAsync(IPAddress.Loopback, 50001); + await PlatformAbstractionLayer.ConnectAsync(clientSocket, IPAddress.Loopback, 50001); await Task.Delay(100, ct.Token); diff --git a/Tests/MQTTnet.Core.Tests/Server_Tests.cs b/Tests/MQTTnet.Core.Tests/Server_Tests.cs index cbbc1fc..755dafe 100644 --- a/Tests/MQTTnet.Core.Tests/Server_Tests.cs +++ b/Tests/MQTTnet.Core.Tests/Server_Tests.cs @@ -13,6 +13,7 @@ using MQTTnet.Client.Disconnecting; using MQTTnet.Client.Options; using MQTTnet.Client.Receiving; using MQTTnet.Client.Subscribing; +using MQTTnet.Implementations; using MQTTnet.Protocol; using MQTTnet.Server; using MQTTnet.Tests.Mockups; @@ -1140,7 +1141,7 @@ namespace MQTTnet.Tests await testEnvironment.ConnectClientAsync(); } } - + [TestMethod] public async Task Close_Idle_Connection() { @@ -1149,14 +1150,14 @@ namespace MQTTnet.Tests await testEnvironment.StartServerAsync(new MqttServerOptionsBuilder().WithDefaultCommunicationTimeout(TimeSpan.FromSeconds(1))); var client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); - await client.ConnectAsync("localhost", testEnvironment.ServerPort); + await PlatformAbstractionLayer.ConnectAsync(client, "localhost", testEnvironment.ServerPort); // Don't send anything. The server should close the connection. await Task.Delay(TimeSpan.FromSeconds(3)); try { - var receivedBytes = await client.ReceiveAsync(new ArraySegment(new byte[10]), SocketFlags.Partial); + var receivedBytes = await PlatformAbstractionLayer.ReceiveAsync(client, new ArraySegment(new byte[10]), SocketFlags.Partial); if (receivedBytes == 0) { return; @@ -1180,7 +1181,7 @@ namespace MQTTnet.Tests // Send an invalid packet and ensure that the server will close the connection and stay in a waiting state // forever. This is security related. var client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); - await client.ConnectAsync("localhost", testEnvironment.ServerPort); + await PlatformAbstractionLayer.ConnectAsync(client, "localhost", testEnvironment.ServerPort); var buffer = Encoding.UTF8.GetBytes("Garbage"); client.Send(buffer, buffer.Length, SocketFlags.None); @@ -1189,7 +1190,7 @@ namespace MQTTnet.Tests try { - var receivedBytes = await client.ReceiveAsync(new ArraySegment(new byte[10]), SocketFlags.Partial); + var receivedBytes = await PlatformAbstractionLayer.ReceiveAsync(client, new ArraySegment(new byte[10]), SocketFlags.Partial); if (receivedBytes == 0) { return;