@@ -72,11 +72,7 @@ namespace MQTTnet.Implementations | |||||
// Workaround for: workaround for https://github.com/dotnet/corefx/issues/24430 | // Workaround for: workaround for https://github.com/dotnet/corefx/issues/24430 | ||||
using (cancellationToken.Register(() => socket.Dispose())) | using (cancellationToken.Register(() => socket.Dispose())) | ||||
{ | { | ||||
#if NET452 || NET461 | |||||
await Task.Factory.FromAsync(socket.BeginConnect, socket.EndConnect, _options.Server, _options.GetPort(), null).ConfigureAwait(false); | |||||
#else | |||||
await socket.ConnectAsync(_options.Server, _options.GetPort()).ConfigureAwait(false); | |||||
#endif | |||||
await PlatformAbstractionLayer.ConnectAsync(socket, _options.Server, _options.GetPort()).ConfigureAwait(false); | |||||
} | } | ||||
var networkStream = new NetworkStream(socket, true); | var networkStream = new NetworkStream(socket, true); | ||||
@@ -117,6 +113,10 @@ namespace MQTTnet.Implementations | |||||
return await _stream.ReadAsync(buffer, offset, count, cancellationToken).ConfigureAwait(false); | return await _stream.ReadAsync(buffer, offset, count, cancellationToken).ConfigureAwait(false); | ||||
} | } | ||||
} | } | ||||
catch (ObjectDisposedException) | |||||
{ | |||||
return 0; | |||||
} | |||||
catch (IOException exception) | catch (IOException exception) | ||||
{ | { | ||||
if (exception.InnerException is SocketException socketException) | if (exception.InnerException is SocketException socketException) | ||||
@@ -143,6 +143,10 @@ namespace MQTTnet.Implementations | |||||
await _stream.WriteAsync(buffer, offset, count, cancellationToken).ConfigureAwait(false); | await _stream.WriteAsync(buffer, offset, count, cancellationToken).ConfigureAwait(false); | ||||
} | } | ||||
} | } | ||||
catch (ObjectDisposedException) | |||||
{ | |||||
return; | |||||
} | |||||
catch (IOException exception) | catch (IOException exception) | ||||
{ | { | ||||
if (exception.InnerException is SocketException socketException) | if (exception.InnerException is SocketException socketException) | ||||
@@ -107,12 +107,7 @@ namespace MQTTnet.Implementations | |||||
{ | { | ||||
try | try | ||||
{ | { | ||||
#if NET452 || NET461 | |||||
var clientSocket = await Task.Factory.FromAsync(_socket.BeginAccept, _socket.EndAccept, null).ConfigureAwait(false); | |||||
#else | |||||
var clientSocket = await _socket.AcceptAsync().ConfigureAwait(false); | |||||
#endif | |||||
var clientSocket = await PlatformAbstractionLayer.AcceptAsync(_socket).ConfigureAwait(false); | |||||
if (clientSocket == null) | if (clientSocket == null) | ||||
{ | { | ||||
continue; | continue; | ||||
@@ -0,0 +1,92 @@ | |||||
using System; | |||||
using System.Net; | |||||
using System.Net.Sockets; | |||||
using System.Threading.Tasks; | |||||
namespace MQTTnet.Implementations | |||||
{ | |||||
public static class PlatformAbstractionLayer | |||||
{ | |||||
public static async Task<Socket> AcceptAsync(Socket socket) | |||||
{ | |||||
#if NET452 || NET461 | |||||
try | |||||
{ | |||||
return await Task.Factory.FromAsync(socket.BeginAccept, socket.EndAccept, null).ConfigureAwait(false); | |||||
} | |||||
catch (ObjectDisposedException) | |||||
{ | |||||
return null; | |||||
} | |||||
#else | |||||
return await socket.AcceptAsync().ConfigureAwait(false); | |||||
#endif | |||||
} | |||||
public static Task ConnectAsync(Socket socket, IPAddress ip, int port) | |||||
{ | |||||
#if NET452 || NET461 | |||||
return Task.Factory.FromAsync(socket.BeginConnect, socket.EndConnect, ip, port, null); | |||||
#else | |||||
return socket.ConnectAsync(ip, port); | |||||
#endif | |||||
} | |||||
public static Task ConnectAsync(Socket socket, string host, int port) | |||||
{ | |||||
#if NET452 || NET461 | |||||
return Task.Factory.FromAsync(socket.BeginConnect, socket.EndConnect, host, port, null); | |||||
#else | |||||
return socket.ConnectAsync(host, port); | |||||
#endif | |||||
} | |||||
#if NET452 || NET461 | |||||
public class SocketWrapper | |||||
{ | |||||
private readonly Socket _socket; | |||||
private readonly ArraySegment<byte> _buffer; | |||||
private readonly SocketFlags _socketFlags; | |||||
public SocketWrapper(Socket socket, ArraySegment<byte> buffer, SocketFlags socketFlags) | |||||
{ | |||||
_socket = socket; | |||||
_buffer = buffer; | |||||
_socketFlags = socketFlags; | |||||
} | |||||
public static IAsyncResult BeginSend(AsyncCallback callback, object state) | |||||
{ | |||||
var real = (SocketWrapper)state; | |||||
return real._socket.BeginSend(real._buffer.Array, real._buffer.Offset, real._buffer.Count, real._socketFlags, callback, state); | |||||
} | |||||
public static IAsyncResult BeginReceive(AsyncCallback callback, object state) | |||||
{ | |||||
var real = (SocketWrapper)state; | |||||
return real._socket.BeginReceive(real._buffer.Array, real._buffer.Offset, real._buffer.Count, real._socketFlags, callback, state); | |||||
} | |||||
} | |||||
#endif | |||||
public static Task SendAsync(Socket socket, ArraySegment<byte> buffer, SocketFlags socketFlags) | |||||
{ | |||||
#if NET452 || NET461 | |||||
return Task.Factory.FromAsync(SocketWrapper.BeginSend, socket.EndSend, new SocketWrapper(socket, buffer, socketFlags)); | |||||
#else | |||||
return socket.SendAsync(buffer, socketFlags); | |||||
#endif | |||||
} | |||||
public static Task<int> ReceiveAsync(Socket socket, ArraySegment<byte> buffer, SocketFlags socketFlags) | |||||
{ | |||||
#if NET452 || NET461 | |||||
return Task.Factory.FromAsync(SocketWrapper.BeginReceive, socket.EndReceive, new SocketWrapper(socket, buffer, socketFlags)); | |||||
#else | |||||
return socket.ReceiveAsync(buffer, socketFlags); | |||||
#endif | |||||
} | |||||
} | |||||
} |
@@ -1,7 +1,7 @@ | |||||
<Project Sdk="Microsoft.NET.Sdk"> | <Project Sdk="Microsoft.NET.Sdk"> | ||||
<PropertyGroup> | <PropertyGroup> | ||||
<TargetFramework>netcoreapp2.2</TargetFramework> | |||||
<TargetFrameworks>netcoreapp2.2;net461</TargetFrameworks> | |||||
<IsPackable>false</IsPackable> | <IsPackable>false</IsPackable> | ||||
</PropertyGroup> | </PropertyGroup> | ||||
@@ -39,7 +39,7 @@ namespace MQTTnet.Tests.Mockups | |||||
case MqttClientOptionsBuilder builder: | case MqttClientOptionsBuilder builder: | ||||
{ | { | ||||
var existingClientId = builder.Build().ClientId; | var existingClientId = builder.Build().ClientId; | ||||
if (!existingClientId.StartsWith(TestContext.TestName)) | |||||
if (existingClientId != null && !existingClientId.StartsWith(TestContext.TestName)) | |||||
{ | { | ||||
builder.WithClientId(TestContext.TestName + existingClientId); | builder.WithClientId(TestContext.TestName + existingClientId); | ||||
} | } | ||||
@@ -48,7 +48,7 @@ namespace MQTTnet.Tests.Mockups | |||||
case MqttClientOptions op: | case MqttClientOptions op: | ||||
{ | { | ||||
var existingClientId = op.ClientId; | var existingClientId = op.ClientId; | ||||
if (!existingClientId.StartsWith(TestContext.TestName)) | |||||
if (existingClientId != null && !existingClientId.StartsWith(TestContext.TestName)) | |||||
{ | { | ||||
op.ClientId = TestContext.TestName + existingClientId; | op.ClientId = TestContext.TestName + existingClientId; | ||||
} | } | ||||
@@ -237,6 +237,36 @@ namespace MQTTnet.Tests | |||||
} | } | ||||
} | } | ||||
[TestMethod] | |||||
public async Task ConnectTimeout_Throws_Exception() | |||||
{ | |||||
var factory = new MqttFactory(); | |||||
using (var client = factory.CreateMqttClient()) | |||||
{ | |||||
bool disconnectHandlerCalled = false; | |||||
try | |||||
{ | |||||
client.DisconnectedHandler = new MqttClientDisconnectedHandlerDelegate(args => | |||||
{ | |||||
disconnectHandlerCalled = true; | |||||
}); | |||||
await client.ConnectAsync(new MqttClientOptionsBuilder().WithTcpServer("1.2.3.4").Build()); | |||||
Assert.Fail("Must fail!"); | |||||
} | |||||
catch (Exception exception) | |||||
{ | |||||
Assert.IsNotNull(exception); | |||||
Assert.IsInstanceOfType(exception, typeof(MqttCommunicationException)); | |||||
//Assert.IsInstanceOfType(exception.InnerException, typeof(SocketException)); | |||||
} | |||||
await Task.Delay(100); // disconnected handler is called async | |||||
Assert.IsTrue(disconnectHandlerCalled); | |||||
} | |||||
} | |||||
[TestMethod] | [TestMethod] | ||||
public async Task Fire_Disconnected_Event_On_Server_Shutdown() | public async Task Fire_Disconnected_Event_On_Server_Shutdown() | ||||
{ | { | ||||
@@ -28,14 +28,14 @@ namespace MQTTnet.Tests | |||||
{ | { | ||||
while (!ct.IsCancellationRequested) | while (!ct.IsCancellationRequested) | ||||
{ | { | ||||
var client = await serverSocket.AcceptAsync(); | |||||
var client = await PlatformAbstractionLayer.AcceptAsync(serverSocket); | |||||
var data = new byte[] { 128 }; | var data = new byte[] { 128 }; | ||||
await client.SendAsync(new ArraySegment<byte>(data), SocketFlags.None); | |||||
await PlatformAbstractionLayer.SendAsync(client, new ArraySegment<byte>(data), SocketFlags.None); | |||||
} | } | ||||
}, ct.Token); | }, ct.Token); | ||||
var clientSocket = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); | var clientSocket = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); | ||||
await clientSocket.ConnectAsync(IPAddress.Loopback, 50001); | |||||
await PlatformAbstractionLayer.ConnectAsync(clientSocket, IPAddress.Loopback, 50001); | |||||
await Task.Delay(100, ct.Token); | await Task.Delay(100, ct.Token); | ||||
@@ -13,6 +13,7 @@ using MQTTnet.Client.Disconnecting; | |||||
using MQTTnet.Client.Options; | using MQTTnet.Client.Options; | ||||
using MQTTnet.Client.Receiving; | using MQTTnet.Client.Receiving; | ||||
using MQTTnet.Client.Subscribing; | using MQTTnet.Client.Subscribing; | ||||
using MQTTnet.Implementations; | |||||
using MQTTnet.Protocol; | using MQTTnet.Protocol; | ||||
using MQTTnet.Server; | using MQTTnet.Server; | ||||
using MQTTnet.Tests.Mockups; | using MQTTnet.Tests.Mockups; | ||||
@@ -1140,7 +1141,7 @@ namespace MQTTnet.Tests | |||||
await testEnvironment.ConnectClientAsync(); | await testEnvironment.ConnectClientAsync(); | ||||
} | } | ||||
} | } | ||||
[TestMethod] | [TestMethod] | ||||
public async Task Close_Idle_Connection() | public async Task Close_Idle_Connection() | ||||
{ | { | ||||
@@ -1149,14 +1150,14 @@ namespace MQTTnet.Tests | |||||
await testEnvironment.StartServerAsync(new MqttServerOptionsBuilder().WithDefaultCommunicationTimeout(TimeSpan.FromSeconds(1))); | await testEnvironment.StartServerAsync(new MqttServerOptionsBuilder().WithDefaultCommunicationTimeout(TimeSpan.FromSeconds(1))); | ||||
var client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); | var client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); | ||||
await client.ConnectAsync("localhost", testEnvironment.ServerPort); | |||||
await PlatformAbstractionLayer.ConnectAsync(client, "localhost", testEnvironment.ServerPort); | |||||
// Don't send anything. The server should close the connection. | // Don't send anything. The server should close the connection. | ||||
await Task.Delay(TimeSpan.FromSeconds(3)); | await Task.Delay(TimeSpan.FromSeconds(3)); | ||||
try | try | ||||
{ | { | ||||
var receivedBytes = await client.ReceiveAsync(new ArraySegment<byte>(new byte[10]), SocketFlags.Partial); | |||||
var receivedBytes = await PlatformAbstractionLayer.ReceiveAsync(client, new ArraySegment<byte>(new byte[10]), SocketFlags.Partial); | |||||
if (receivedBytes == 0) | if (receivedBytes == 0) | ||||
{ | { | ||||
return; | return; | ||||
@@ -1180,7 +1181,7 @@ namespace MQTTnet.Tests | |||||
// Send an invalid packet and ensure that the server will close the connection and stay in a waiting state | // Send an invalid packet and ensure that the server will close the connection and stay in a waiting state | ||||
// forever. This is security related. | // forever. This is security related. | ||||
var client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); | var client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); | ||||
await client.ConnectAsync("localhost", testEnvironment.ServerPort); | |||||
await PlatformAbstractionLayer.ConnectAsync(client, "localhost", testEnvironment.ServerPort); | |||||
var buffer = Encoding.UTF8.GetBytes("Garbage"); | var buffer = Encoding.UTF8.GetBytes("Garbage"); | ||||
client.Send(buffer, buffer.Length, SocketFlags.None); | client.Send(buffer, buffer.Length, SocketFlags.None); | ||||
@@ -1189,7 +1190,7 @@ namespace MQTTnet.Tests | |||||
try | try | ||||
{ | { | ||||
var receivedBytes = await client.ReceiveAsync(new ArraySegment<byte>(new byte[10]), SocketFlags.Partial); | |||||
var receivedBytes = await PlatformAbstractionLayer.ReceiveAsync(client, new ArraySegment<byte>(new byte[10]), SocketFlags.Partial); | |||||
if (receivedBytes == 0) | if (receivedBytes == 0) | ||||
{ | { | ||||
return; | return; | ||||