diff --git a/src/libraries/Common/src/Interop/Unix/System.Native/Interop.HostEntry.cs b/src/libraries/Common/src/Interop/Unix/System.Native/Interop.HostEntry.cs index 7578452daa585..bfecbc412dbcd 100644 --- a/src/libraries/Common/src/Interop/Unix/System.Native/Interop.HostEntry.cs +++ b/src/libraries/Common/src/Interop/Unix/System.Native/Interop.HostEntry.cs @@ -33,7 +33,7 @@ internal unsafe struct HostEntry } [DllImport(Libraries.SystemNative, EntryPoint = "SystemNative_GetHostEntryForName")] - internal static extern unsafe int GetHostEntryForName(string address, HostEntry* entry); + internal static extern unsafe int GetHostEntryForName(string address, System.Net.Sockets.AddressFamily family, HostEntry* entry); [DllImport(Libraries.SystemNative, EntryPoint = "SystemNative_FreeHostEntry")] internal static extern unsafe void FreeHostEntry(HostEntry* entry); diff --git a/src/libraries/Common/src/Interop/Windows/WinSock/Interop.GetAddrInfoExW.cs b/src/libraries/Common/src/Interop/Windows/WinSock/Interop.GetAddrInfoExW.cs index c5c233c527513..427586629dc4a 100644 --- a/src/libraries/Common/src/Interop/Windows/WinSock/Interop.GetAddrInfoExW.cs +++ b/src/libraries/Common/src/Interop/Windows/WinSock/Interop.GetAddrInfoExW.cs @@ -11,6 +11,9 @@ internal static partial class Interop { internal static partial class Winsock { + internal const int WSA_INVALID_HANDLE = 6; + internal const int WSA_E_CANCELLED = 10111; + internal const string GetAddrInfoExCancelFunctionName = "GetAddrInfoExCancel"; internal const int NS_ALL = 0; @@ -28,6 +31,9 @@ internal static extern unsafe int GetAddrInfoExW( [In] delegate* unmanaged lpCompletionRoutine, [Out] IntPtr* lpNameHandle); + [DllImport(Libraries.Ws2_32, ExactSpelling = true)] + internal static extern unsafe int GetAddrInfoExCancel([In] IntPtr* lpHandle); + [DllImport(Libraries.Ws2_32, ExactSpelling = true)] internal static extern unsafe void FreeAddrInfoExW(AddressInfoEx* pAddrInfo); diff --git a/src/libraries/Native/Unix/System.Native/pal_networking.c b/src/libraries/Native/Unix/System.Native/pal_networking.c index 09ec260149f49..2640f3180c8a6 100644 --- a/src/libraries/Native/Unix/System.Native/pal_networking.c +++ b/src/libraries/Native/Unix/System.Native/pal_networking.c @@ -170,6 +170,80 @@ c_static_assert(offsetof(IOVector, Count) == offsetof(iovec, iov_len)); #define Min(left,right) (((left) < (right)) ? (left) : (right)) +static bool TryConvertAddressFamilyPlatformToPal(sa_family_t platformAddressFamily, int32_t* palAddressFamily) +{ + assert(palAddressFamily != NULL); + + switch (platformAddressFamily) + { + case AF_UNSPEC: + *palAddressFamily = AddressFamily_AF_UNSPEC; + return true; + + case AF_UNIX: + *palAddressFamily = AddressFamily_AF_UNIX; + return true; + + case AF_INET: + *palAddressFamily = AddressFamily_AF_INET; + return true; + + case AF_INET6: + *palAddressFamily = AddressFamily_AF_INET6; + return true; +#ifdef AF_PACKET + case AF_PACKET: + *palAddressFamily = AddressFamily_AF_PACKET; + return true; +#endif +#ifdef AF_CAN + case AF_CAN: + *palAddressFamily = AddressFamily_AF_CAN; + return true; +#endif + default: + *palAddressFamily = platformAddressFamily; + return false; + } +} + +static bool TryConvertAddressFamilyPalToPlatform(int32_t palAddressFamily, sa_family_t* platformAddressFamily) +{ + assert(platformAddressFamily != NULL); + + switch (palAddressFamily) + { + case AddressFamily_AF_UNSPEC: + *platformAddressFamily = AF_UNSPEC; + return true; + + case AddressFamily_AF_UNIX: + *platformAddressFamily = AF_UNIX; + return true; + + case AddressFamily_AF_INET: + *platformAddressFamily = AF_INET; + return true; + + case AddressFamily_AF_INET6: + *platformAddressFamily = AF_INET6; + return true; +#ifdef AF_PACKET + case AddressFamily_AF_PACKET: + *platformAddressFamily = AF_PACKET; + return true; +#endif +#ifdef AF_CAN + case AddressFamily_AF_CAN: + *platformAddressFamily = AF_CAN; + return true; +#endif + default: + *platformAddressFamily = (sa_family_t)palAddressFamily; + return false; + } +} + static void ConvertByteArrayToIn6Addr(struct in6_addr* addr, const uint8_t* buffer, int32_t bufferLength) { assert(bufferLength == NUM_BYTES_IN_IPV6_ADDRESS); @@ -261,7 +335,7 @@ static int32_t CopySockAddrToIPAddress(sockaddr* addr, sa_family_t family, IPAdd return -1; } -int32_t SystemNative_GetHostEntryForName(const uint8_t* address, HostEntry* entry) +int32_t SystemNative_GetHostEntryForName(const uint8_t* address, int32_t addressFamily, HostEntry* entry) { if (address == NULL || entry == NULL) { @@ -275,11 +349,16 @@ int32_t SystemNative_GetHostEntryForName(const uint8_t* address, HostEntry* entr struct ifaddrs* addrs = NULL; #endif - // Get all address families and the canonical name + sa_family_t platformFamily; + if (!TryConvertAddressFamilyPalToPlatform(addressFamily, &platformFamily)) + { + return GetAddrInfoErrorFlags_EAI_FAMILY; + } + struct addrinfo hint; memset(&hint, 0, sizeof(struct addrinfo)); - hint.ai_family = AF_UNSPEC; hint.ai_flags = AI_CANONNAME; + hint.ai_family = platformFamily; int result = getaddrinfo((const char*)address, NULL, &hint, &info); if (result != 0) @@ -593,80 +672,6 @@ int32_t SystemNative_GetIPSocketAddressSizes(int32_t* ipv4SocketAddressSize, int return Error_SUCCESS; } -static bool TryConvertAddressFamilyPlatformToPal(sa_family_t platformAddressFamily, int32_t* palAddressFamily) -{ - assert(palAddressFamily != NULL); - - switch (platformAddressFamily) - { - case AF_UNSPEC: - *palAddressFamily = AddressFamily_AF_UNSPEC; - return true; - - case AF_UNIX: - *palAddressFamily = AddressFamily_AF_UNIX; - return true; - - case AF_INET: - *palAddressFamily = AddressFamily_AF_INET; - return true; - - case AF_INET6: - *palAddressFamily = AddressFamily_AF_INET6; - return true; -#ifdef AF_PACKET - case AF_PACKET: - *palAddressFamily = AddressFamily_AF_PACKET; - return true; -#endif -#ifdef AF_CAN - case AF_CAN: - *palAddressFamily = AddressFamily_AF_CAN; - return true; -#endif - default: - *palAddressFamily = platformAddressFamily; - return false; - } -} - -static bool TryConvertAddressFamilyPalToPlatform(int32_t palAddressFamily, sa_family_t* platformAddressFamily) -{ - assert(platformAddressFamily != NULL); - - switch (palAddressFamily) - { - case AddressFamily_AF_UNSPEC: - *platformAddressFamily = AF_UNSPEC; - return true; - - case AddressFamily_AF_UNIX: - *platformAddressFamily = AF_UNIX; - return true; - - case AddressFamily_AF_INET: - *platformAddressFamily = AF_INET; - return true; - - case AddressFamily_AF_INET6: - *platformAddressFamily = AF_INET6; - return true; -#ifdef AF_PACKET - case AddressFamily_AF_PACKET: - *platformAddressFamily = AF_PACKET; - return true; -#endif -#ifdef AF_CAN - case AddressFamily_AF_CAN: - *platformAddressFamily = AF_CAN; - return true; -#endif - default: - *platformAddressFamily = (sa_family_t)palAddressFamily; - return false; - } -} - int32_t SystemNative_GetAddressFamily(const uint8_t* socketAddress, int32_t socketAddressLen, int32_t* addressFamily) { if (socketAddress == NULL || addressFamily == NULL || socketAddressLen < 0) diff --git a/src/libraries/Native/Unix/System.Native/pal_networking.h b/src/libraries/Native/Unix/System.Native/pal_networking.h index 6c2422b3cba37..bbb0bc0785cce 100644 --- a/src/libraries/Native/Unix/System.Native/pal_networking.h +++ b/src/libraries/Native/Unix/System.Native/pal_networking.h @@ -301,7 +301,7 @@ typedef struct uint32_t Padding; // Pad out to 8-byte alignment } SocketEvent; -PALEXPORT int32_t SystemNative_GetHostEntryForName(const uint8_t* address, HostEntry* entry); +PALEXPORT int32_t SystemNative_GetHostEntryForName(const uint8_t* address, int32_t addressFamily, HostEntry* entry); PALEXPORT void SystemNative_FreeHostEntry(HostEntry* entry); diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/AuthenticationHelper.NtAuth.cs b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/AuthenticationHelper.NtAuth.cs index 3b2e569445105..e4218588f8073 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/AuthenticationHelper.NtAuth.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/AuthenticationHelper.NtAuth.cs @@ -110,7 +110,7 @@ private static async Task SendWithNtAuthAsync(HttpRequestMe } else { - IPHostEntry result = await Dns.GetHostEntryAsync(authUri.IdnHost).ConfigureAwait(false); + IPHostEntry result = await Dns.GetHostEntryAsync(authUri.IdnHost, cancellationToken).ConfigureAwait(false); hostName = result.HostName; } diff --git a/src/libraries/System.Net.NameResolution/ref/System.Net.NameResolution.cs b/src/libraries/System.Net.NameResolution/ref/System.Net.NameResolution.cs index 0ea53e719c865..b92fd42c3aab5 100644 --- a/src/libraries/System.Net.NameResolution/ref/System.Net.NameResolution.cs +++ b/src/libraries/System.Net.NameResolution/ref/System.Net.NameResolution.cs @@ -22,7 +22,10 @@ public static partial class Dns [System.ObsoleteAttribute("EndResolve is obsoleted for this type, please use EndGetHostEntry instead. https://go.microsoft.com/fwlink/?linkid=14202")] public static System.Net.IPHostEntry EndResolve(System.IAsyncResult asyncResult) { throw null; } public static System.Net.IPAddress[] GetHostAddresses(string hostNameOrAddress) { throw null; } + public static System.Net.IPAddress[] GetHostAddresses(string hostNameOrAddress, System.Net.Sockets.AddressFamily family) { throw null; } public static System.Threading.Tasks.Task GetHostAddressesAsync(string hostNameOrAddress) { throw null; } + public static System.Threading.Tasks.Task GetHostAddressesAsync(string hostNameOrAddress, System.Net.Sockets.AddressFamily family, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } + public static System.Threading.Tasks.Task GetHostAddressesAsync(string hostNameOrAddress, System.Threading.CancellationToken cancellationToken) { throw null; } [System.ObsoleteAttribute("GetHostByAddress is obsoleted for this type, please use GetHostEntry instead. https://go.microsoft.com/fwlink/?linkid=14202")] public static System.Net.IPHostEntry GetHostByAddress(System.Net.IPAddress address) { throw null; } [System.ObsoleteAttribute("GetHostByAddress is obsoleted for this type, please use GetHostEntry instead. https://go.microsoft.com/fwlink/?linkid=14202")] @@ -31,8 +34,11 @@ public static partial class Dns public static System.Net.IPHostEntry GetHostByName(string hostName) { throw null; } public static System.Net.IPHostEntry GetHostEntry(System.Net.IPAddress address) { throw null; } public static System.Net.IPHostEntry GetHostEntry(string hostNameOrAddress) { throw null; } + public static System.Net.IPHostEntry GetHostEntry(string hostNameOrAddress, System.Net.Sockets.AddressFamily family) { throw null; } public static System.Threading.Tasks.Task GetHostEntryAsync(System.Net.IPAddress address) { throw null; } public static System.Threading.Tasks.Task GetHostEntryAsync(string hostNameOrAddress) { throw null; } + public static System.Threading.Tasks.Task GetHostEntryAsync(string hostNameOrAddress, System.Net.Sockets.AddressFamily family, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } + public static System.Threading.Tasks.Task GetHostEntryAsync(string hostNameOrAddress, System.Threading.CancellationToken cancellationToken) { throw null; } public static string GetHostName() { throw null; } [System.ObsoleteAttribute("Resolve is obsoleted for this type, please use GetHostEntry instead. https://go.microsoft.com/fwlink/?linkid=14202")] public static System.Net.IPHostEntry Resolve(string hostName) { throw null; } diff --git a/src/libraries/System.Net.NameResolution/src/System/Net/Dns.cs b/src/libraries/System.Net.NameResolution/src/System/Net/Dns.cs index 942db43c71bd8..d7675c78dc040 100644 --- a/src/libraries/System.Net.NameResolution/src/System/Net/Dns.cs +++ b/src/libraries/System.Net.NameResolution/src/System/Net/Dns.cs @@ -50,13 +50,24 @@ public static IPHostEntry GetHostEntry(IPAddress address) throw new ArgumentException(SR.Format(SR.net_invalid_ip_addr, nameof(address))); } - IPHostEntry ipHostEntry = GetHostEntryCore(address); + IPHostEntry ipHostEntry = GetHostEntryCore(address, AddressFamily.Unspecified); if (NetEventSource.Log.IsEnabled()) NetEventSource.Info(address, $"{ipHostEntry} with {ipHostEntry.AddressList.Length} entries"); return ipHostEntry; } - public static IPHostEntry GetHostEntry(string hostNameOrAddress) + public static IPHostEntry GetHostEntry(string hostNameOrAddress) => + GetHostEntry(hostNameOrAddress, AddressFamily.Unspecified); + + /// + /// Resolves a host name or IP address to an instance. + /// + /// The host name or IP address to resolve. + /// The address family for which IPs should be retrieved. If , retrieve all IPs regardless of address family. + /// + /// An instance that contains the address information about the host specified in . + /// + public static IPHostEntry GetHostEntry(string hostNameOrAddress, AddressFamily family) { if (hostNameOrAddress is null) { @@ -73,29 +84,54 @@ public static IPHostEntry GetHostEntry(string hostNameOrAddress) throw new ArgumentException(SR.Format(SR.net_invalid_ip_addr, nameof(hostNameOrAddress))); } - ipHostEntry = GetHostEntryCore(address); + ipHostEntry = GetHostEntryCore(address, family); } else { - ipHostEntry = GetHostEntryCore(hostNameOrAddress); + ipHostEntry = GetHostEntryCore(hostNameOrAddress, family); } if (NetEventSource.Log.IsEnabled()) NetEventSource.Info(hostNameOrAddress, $"{ipHostEntry} with {ipHostEntry.AddressList.Length} entries"); return ipHostEntry; } - public static Task GetHostEntryAsync(string hostNameOrAddress) + public static Task GetHostEntryAsync(string hostNameOrAddress) => + GetHostEntryAsync(hostNameOrAddress, AddressFamily.Unspecified, CancellationToken.None); + + /// + /// Resolves a host name or IP address to an instance as an asynchronous operation. + /// + /// The host name or IP address to resolve. + /// A cancellation token that can be used to signal the asynchronous operation should be canceled. + /// + /// The task object representing the asynchronous operation. The property on the task object returns + /// an instance that contains the address information about the host specified in . + /// + public static Task GetHostEntryAsync(string hostNameOrAddress, CancellationToken cancellationToken) => + GetHostEntryAsync(hostNameOrAddress, AddressFamily.Unspecified, cancellationToken); + + /// + /// Resolves a host name or IP address to an instance as an asynchronous operation. + /// + /// The host name or IP address to resolve. + /// The address family for which IPs should be retrieved. If , retrieve all IPs regardless of address family. + /// A cancellation token that can be used to signal the asynchronous operation should be canceled. + /// + /// The task object representing the asynchronous operation. The property on the task object returns + /// an instance that contains the address information about the host specified in . + /// + public static Task GetHostEntryAsync(string hostNameOrAddress, AddressFamily family, CancellationToken cancellationToken = default) { if (NetEventSource.Log.IsEnabled()) { - Task t = GetHostEntryCoreAsync(hostNameOrAddress, justReturnParsedIp: false, throwOnIIPAny: true); + Task t = GetHostEntryCoreAsync(hostNameOrAddress, justReturnParsedIp: false, throwOnIIPAny: true, family, cancellationToken); t.ContinueWith((t, s) => NetEventSource.Info((string)s!, $"{t.Result} with {((IPHostEntry)t.Result).AddressList.Length} entries"), hostNameOrAddress, CancellationToken.None, TaskContinuationOptions.ExecuteSynchronously | TaskContinuationOptions.OnlyOnRanToCompletion, TaskScheduler.Default); return t; } else { - return GetHostEntryCoreAsync(hostNameOrAddress, justReturnParsedIp: false, throwOnIIPAny: true); + return GetHostEntryCoreAsync(hostNameOrAddress, justReturnParsedIp: false, throwOnIIPAny: true, family, cancellationToken); } } @@ -113,7 +149,7 @@ public static Task GetHostEntryAsync(IPAddress address) } return RunAsync(s => { - IPHostEntry ipHostEntry = GetHostEntryCore((IPAddress)s); + IPHostEntry ipHostEntry = GetHostEntryCore((IPAddress)s, AddressFamily.Unspecified); if (NetEventSource.Log.IsEnabled()) NetEventSource.Info((IPAddress)s, $"{ipHostEntry} with {ipHostEntry.AddressList.Length} entries"); return ipHostEntry; }, address); @@ -129,6 +165,17 @@ public static IPHostEntry EndGetHostEntry(IAsyncResult asyncResult) => TaskToApm.End(asyncResult ?? throw new ArgumentNullException(nameof(asyncResult))); public static IPAddress[] GetHostAddresses(string hostNameOrAddress) + => GetHostAddresses(hostNameOrAddress, AddressFamily.Unspecified); + + /// + /// Returns the Internet Protocol (IP) addresses for the specified host. + /// + /// The host name or IP address to resolve. + /// The address family for which IPs should be retrieved. If , retrieve all IPs regardless of address family. + /// + /// An array of type that holds the IP addresses for the host that is specified by the parameter. + /// + public static IPAddress[] GetHostAddresses(string hostNameOrAddress, AddressFamily family) { if (hostNameOrAddress is null) { @@ -145,11 +192,11 @@ public static IPAddress[] GetHostAddresses(string hostNameOrAddress) throw new ArgumentException(SR.Format(SR.net_invalid_ip_addr, nameof(hostNameOrAddress))); } - addresses = new IPAddress[] { address }; + addresses = (family == AddressFamily.Unspecified || address.AddressFamily == family) ? new IPAddress[] { address } : Array.Empty(); } else { - addresses = GetHostAddressesCore(hostNameOrAddress); + addresses = GetHostAddressesCore(hostNameOrAddress, family); } if (NetEventSource.Log.IsEnabled()) NetEventSource.Info(hostNameOrAddress, addresses); @@ -157,7 +204,32 @@ public static IPAddress[] GetHostAddresses(string hostNameOrAddress) } public static Task GetHostAddressesAsync(string hostNameOrAddress) => - (Task)GetHostEntryOrAddressesCoreAsync(hostNameOrAddress, justReturnParsedIp: true, throwOnIIPAny: true, justAddresses: true); + (Task)GetHostEntryOrAddressesCoreAsync(hostNameOrAddress, justReturnParsedIp: true, throwOnIIPAny: true, justAddresses: true, AddressFamily.Unspecified, CancellationToken.None); + + /// + /// Returns the Internet Protocol (IP) addresses for the specified host as an asynchronous operation. + /// + /// The host name or IP address to resolve. + /// A cancellation token that can be used to signal the asynchronous operation should be canceled. + /// + /// The task object representing the asynchronous operation. The property on the task object returns an array of + /// type that holds the IP addresses for the host that is specified by the parameter. + /// + public static Task GetHostAddressesAsync(string hostNameOrAddress, CancellationToken cancellationToken) => + (Task)GetHostEntryOrAddressesCoreAsync(hostNameOrAddress, justReturnParsedIp: true, throwOnIIPAny: true, justAddresses: true, AddressFamily.Unspecified, cancellationToken); + + /// + /// Returns the Internet Protocol (IP) addresses for the specified host as an asynchronous operation. + /// + /// The host name or IP address to resolve. + /// The address family for which IPs should be retrieved. If , retrieve all IPs regardless of address family. + /// A cancellation token that can be used to signal the asynchronous operation should be canceled. + /// + /// The task object representing the asynchronous operation. The property on the task object returns an array of + /// type that holds the IP addresses for the host that is specified by the parameter. + /// + public static Task GetHostAddressesAsync(string hostNameOrAddress, AddressFamily family, CancellationToken cancellationToken = default) => + (Task)GetHostEntryOrAddressesCoreAsync(hostNameOrAddress, justReturnParsedIp: true, throwOnIIPAny: true, justAddresses: true, family, cancellationToken); public static IAsyncResult BeginGetHostAddresses(string hostNameOrAddress, AsyncCallback? requestCallback, object? state) => TaskToApm.Begin(GetHostAddressesAsync(hostNameOrAddress), requestCallback, state); @@ -178,12 +250,12 @@ public static IPHostEntry GetHostByName(string hostName) return CreateHostEntryForAddress(address); } - return GetHostEntryCore(hostName); + return GetHostEntryCore(hostName, AddressFamily.Unspecified); } [Obsolete("BeginGetHostByName is obsoleted for this type, please use BeginGetHostEntry instead. https://go.microsoft.com/fwlink/?linkid=14202")] public static IAsyncResult BeginGetHostByName(string hostName, AsyncCallback? requestCallback, object? stateObject) => - TaskToApm.Begin(GetHostEntryCoreAsync(hostName, justReturnParsedIp: true, throwOnIIPAny: true), requestCallback, stateObject); + TaskToApm.Begin(GetHostEntryCoreAsync(hostName, justReturnParsedIp: true, throwOnIIPAny: true, AddressFamily.Unspecified, CancellationToken.None), requestCallback, stateObject); [Obsolete("EndGetHostByName is obsoleted for this type, please use EndGetHostEntry instead. https://go.microsoft.com/fwlink/?linkid=14202")] public static IPHostEntry EndGetHostByName(IAsyncResult asyncResult) => @@ -197,7 +269,7 @@ public static IPHostEntry GetHostByAddress(string address) throw new ArgumentNullException(nameof(address)); } - IPHostEntry ipHostEntry = GetHostEntryCore(IPAddress.Parse(address)); + IPHostEntry ipHostEntry = GetHostEntryCore(IPAddress.Parse(address), AddressFamily.Unspecified); if (NetEventSource.Log.IsEnabled()) NetEventSource.Info(address, ipHostEntry); return ipHostEntry; @@ -211,7 +283,7 @@ public static IPHostEntry GetHostByAddress(IPAddress address) throw new ArgumentNullException(nameof(address)); } - IPHostEntry ipHostEntry = GetHostEntryCore(address); + IPHostEntry ipHostEntry = GetHostEntryCore(address, AddressFamily.Unspecified); if (NetEventSource.Log.IsEnabled()) NetEventSource.Info(address, ipHostEntry); return ipHostEntry; @@ -232,7 +304,7 @@ public static IPHostEntry Resolve(string hostName) { try { - ipHostEntry = GetHostEntryCore(address); + ipHostEntry = GetHostEntryCore(address, AddressFamily.Unspecified); } catch (SocketException ex) { @@ -242,7 +314,7 @@ public static IPHostEntry Resolve(string hostName) } else { - ipHostEntry = GetHostEntryCore(hostName); + ipHostEntry = GetHostEntryCore(hostName, AddressFamily.Unspecified); } if (NetEventSource.Log.IsEnabled()) NetEventSource.Info(hostName, ipHostEntry); @@ -251,7 +323,7 @@ public static IPHostEntry Resolve(string hostName) [Obsolete("BeginResolve is obsoleted for this type, please use BeginGetHostEntry instead. https://go.microsoft.com/fwlink/?linkid=14202")] public static IAsyncResult BeginResolve(string hostName, AsyncCallback? requestCallback, object? stateObject) => - TaskToApm.Begin(GetHostEntryCoreAsync(hostName, justReturnParsedIp: false, throwOnIIPAny: false), requestCallback, stateObject); + TaskToApm.Begin(GetHostEntryCoreAsync(hostName, justReturnParsedIp: false, throwOnIIPAny: false, AddressFamily.Unspecified, CancellationToken.None), requestCallback, stateObject); [Obsolete("EndResolve is obsoleted for this type, please use EndGetHostEntry instead. https://go.microsoft.com/fwlink/?linkid=14202")] public static IPHostEntry EndResolve(IAsyncResult asyncResult) @@ -264,10 +336,17 @@ public static IPHostEntry EndResolve(IAsyncResult asyncResult) } catch (SocketException ex) { - IPAddress? address = asyncResult switch + object? asyncState = asyncResult switch { - Task t => t.AsyncState as IPAddress, - TaskToApm.TaskAsyncResult twar => twar._task.AsyncState as IPAddress, + Task t => t.AsyncState, + TaskToApm.TaskAsyncResult twar => twar._task.AsyncState, + _ => null + }; + + IPAddress? address = asyncState switch + { + IPAddress a => a, + Tuple t => t.Item1, _ => null }; @@ -275,6 +354,7 @@ public static IPHostEntry EndResolve(IAsyncResult asyncResult) throw; // BeginResolve was called with a HostName, not an IPAddress if (NetEventSource.Log.IsEnabled()) NetEventSource.Error(null, ex); + ipHostEntry = CreateHostEntryForAddress(address); } @@ -282,13 +362,13 @@ public static IPHostEntry EndResolve(IAsyncResult asyncResult) return ipHostEntry; } - private static IPHostEntry GetHostEntryCore(string hostName) => - (IPHostEntry)GetHostEntryOrAddressesCore(hostName, justAddresses: false); + private static IPHostEntry GetHostEntryCore(string hostName, AddressFamily addressFamily) => + (IPHostEntry)GetHostEntryOrAddressesCore(hostName, justAddresses: false, addressFamily); - private static IPAddress[] GetHostAddressesCore(string hostName) => - (IPAddress[])GetHostEntryOrAddressesCore(hostName, justAddresses: true); + private static IPAddress[] GetHostAddressesCore(string hostName, AddressFamily addressFamily) => + (IPAddress[])GetHostEntryOrAddressesCore(hostName, justAddresses: true, addressFamily); - private static object GetHostEntryOrAddressesCore(string hostName, bool justAddresses) + private static object GetHostEntryOrAddressesCore(string hostName, bool justAddresses, AddressFamily addressFamily) { ValidateHostName(hostName); @@ -297,7 +377,7 @@ private static object GetHostEntryOrAddressesCore(string hostName, bool justAddr object result; try { - SocketError errorCode = NameResolutionPal.TryGetAddrInfo(hostName, justAddresses, out string? newHostName, out string[] aliases, out IPAddress[] addresses, out int nativeErrorCode); + SocketError errorCode = NameResolutionPal.TryGetAddrInfo(hostName, justAddresses, addressFamily, out string? newHostName, out string[] aliases, out IPAddress[] addresses, out int nativeErrorCode); if (errorCode != SocketError.Success) { @@ -326,14 +406,14 @@ private static object GetHostEntryOrAddressesCore(string hostName, bool justAddr return result; } - private static IPHostEntry GetHostEntryCore(IPAddress address) => - (IPHostEntry)GetHostEntryOrAddressesCore(address, justAddresses: false); + private static IPHostEntry GetHostEntryCore(IPAddress address, AddressFamily addressFamily) => + (IPHostEntry)GetHostEntryOrAddressesCore(address, justAddresses: false, addressFamily); - private static IPAddress[] GetHostAddressesCore(IPAddress address) => - (IPAddress[])GetHostEntryOrAddressesCore(address, justAddresses: true); + private static IPAddress[] GetHostAddressesCore(IPAddress address, AddressFamily addressFamily) => + (IPAddress[])GetHostEntryOrAddressesCore(address, justAddresses: true, addressFamily); // Does internal IPAddress reverse and then forward lookups (for Legacy and current public methods). - private static object GetHostEntryOrAddressesCore(IPAddress address, bool justAddresses) + private static object GetHostEntryOrAddressesCore(IPAddress address, bool justAddresses, AddressFamily addressFamily) { // Try to get the data for the host from its address. // We need to call getnameinfo first, because getaddrinfo w/ the ipaddress string @@ -371,7 +451,7 @@ private static object GetHostEntryOrAddressesCore(IPAddress address, bool justAd object result; try { - errorCode = NameResolutionPal.TryGetAddrInfo(name, justAddresses, out string? hostName, out string[] aliases, out IPAddress[] addresses, out int nativeErrorCode); + errorCode = NameResolutionPal.TryGetAddrInfo(name, justAddresses, addressFamily, out string? hostName, out string[] aliases, out IPAddress[] addresses, out int nativeErrorCode); if (errorCode != SocketError.Success) { @@ -406,17 +486,26 @@ private static object GetHostEntryOrAddressesCore(IPAddress address, bool justAd return result; } - private static Task GetHostEntryCoreAsync(string hostName, bool justReturnParsedIp, bool throwOnIIPAny) => - (Task)GetHostEntryOrAddressesCoreAsync(hostName, justReturnParsedIp, throwOnIIPAny, justAddresses: false); + private static Task GetHostEntryCoreAsync(string hostName, bool justReturnParsedIp, bool throwOnIIPAny, AddressFamily family, CancellationToken cancellationToken) => + (Task)GetHostEntryOrAddressesCoreAsync(hostName, justReturnParsedIp, throwOnIIPAny, justAddresses: false, family, cancellationToken); // If hostName is an IPString and justReturnParsedIP==true then no reverse lookup will be attempted, but the original address is returned. - private static Task GetHostEntryOrAddressesCoreAsync(string hostName, bool justReturnParsedIp, bool throwOnIIPAny, bool justAddresses) + private static Task GetHostEntryOrAddressesCoreAsync(string hostName, bool justReturnParsedIp, bool throwOnIIPAny, bool justAddresses, AddressFamily family, CancellationToken cancellationToken) { if (hostName is null) { throw new ArgumentNullException(nameof(hostName)); } + if (cancellationToken.IsCancellationRequested) + { + return justAddresses ? (Task) + Task.FromCanceled(cancellationToken) : + Task.FromCanceled(cancellationToken); + } + + object asyncState; + // See if it's an IP Address. if (IPAddress.TryParse(hostName, out IPAddress? ipAddress)) { @@ -429,39 +518,62 @@ private static Task GetHostEntryOrAddressesCoreAsync(string hostName, bool justR if (justReturnParsedIp) { return justAddresses ? (Task) - Task.FromResult(new[] { ipAddress }) : + Task.FromResult(family == AddressFamily.Unspecified || ipAddress.AddressFamily == family ? new[] { ipAddress } : Array.Empty()) : Task.FromResult(CreateHostEntryForAddress(ipAddress)); } - return justAddresses ? (Task) - RunAsync(s => GetHostAddressesCore((IPAddress)s), ipAddress) : - RunAsync(s => GetHostEntryCore((IPAddress)s), ipAddress); + asyncState = family == AddressFamily.Unspecified ? (object)ipAddress : Tuple.Create(ipAddress, family); } - - // If the OS supports it and 'hostName' is not an IP Address, resolve the name asynchronously - // instead of calling the synchronous version in the ThreadPool. - if (NameResolutionPal.SupportsGetAddrInfoAsync && ipAddress is null) + else if (NameResolutionPal.SupportsGetAddrInfoAsync) { +#pragma warning disable CS0162 // Unreachable code detected -- SupportsGetAddrInfoAsync is a constant on *nix. + + // If the OS supports it and 'hostName' is not an IP Address, resolve the name asynchronously + // instead of calling the synchronous version in the ThreadPool. + ValidateHostName(hostName); if (NameResolutionTelemetry.Log.IsEnabled()) { return justAddresses - ? (Task)GetAddrInfoWithTelemetryAsync(hostName, justAddresses) - : (Task)GetAddrInfoWithTelemetryAsync(hostName, justAddresses); + ? (Task)GetAddrInfoWithTelemetryAsync(hostName, justAddresses, family, cancellationToken) + : (Task)GetAddrInfoWithTelemetryAsync(hostName, justAddresses, family, cancellationToken); } else { - return NameResolutionPal.GetAddrInfoAsync(hostName, justAddresses); + return NameResolutionPal.GetAddrInfoAsync(hostName, justAddresses, family, cancellationToken); } } + else + { + asyncState = family == AddressFamily.Unspecified ? (object)hostName : Tuple.Create(hostName, family); + } - return justAddresses ? (Task) - RunAsync(s => GetHostAddressesCore((string)s), hostName) : - RunAsync(s => GetHostEntryCore((string)s), hostName); + if (justAddresses) + { + return RunAsync(s => s switch + { + string h => GetHostAddressesCore(h, AddressFamily.Unspecified), + Tuple t => GetHostAddressesCore(t.Item1, t.Item2), + IPAddress a => GetHostAddressesCore(a, AddressFamily.Unspecified), + Tuple t => GetHostAddressesCore(t.Item1, t.Item2), + _ => null + }, asyncState); + } + else + { + return RunAsync(s => s switch + { + string h => GetHostEntryCore(h, AddressFamily.Unspecified), + Tuple t => GetHostEntryCore(t.Item1, t.Item2), + IPAddress a => GetHostEntryCore(a, AddressFamily.Unspecified), + Tuple t => GetHostEntryCore(t.Item1, t.Item2), + _ => null + }, asyncState); + } } - private static async Task GetAddrInfoWithTelemetryAsync(string hostName, bool justAddresses) + private static async Task GetAddrInfoWithTelemetryAsync(string hostName, bool justAddresses, AddressFamily addressFamily, CancellationToken cancellationToken) where T : class { ValueStopwatch stopwatch = NameResolutionTelemetry.Log.BeforeResolution(hostName); @@ -469,7 +581,7 @@ private static async Task GetAddrInfoWithTelemetryAsync(string hostName, b T? result = null; try { - result = await ((Task)NameResolutionPal.GetAddrInfoAsync(hostName, justAddresses)).ConfigureAwait(false); + result = await ((Task)NameResolutionPal.GetAddrInfoAsync(hostName, justAddresses, addressFamily, cancellationToken)).ConfigureAwait(false); return result; } finally diff --git a/src/libraries/System.Net.NameResolution/src/System/Net/NameResolutionPal.Unix.cs b/src/libraries/System.Net.NameResolution/src/System/Net/NameResolutionPal.Unix.cs index add10313ccd44..cc9037b032a15 100644 --- a/src/libraries/System.Net.NameResolution/src/System/Net/NameResolutionPal.Unix.cs +++ b/src/libraries/System.Net.NameResolution/src/System/Net/NameResolutionPal.Unix.cs @@ -7,6 +7,7 @@ using System.Net.Sockets; using System.Runtime.InteropServices; using System.Text; +using System.Threading; using System.Threading.Tasks; namespace System.Net @@ -15,7 +16,7 @@ internal static partial class NameResolutionPal { public const bool SupportsGetAddrInfoAsync = false; - internal static Task GetAddrInfoAsync(string hostName, bool justAddresses) => + internal static Task GetAddrInfoAsync(string hostName, bool justAddresses, AddressFamily family, CancellationToken cancellationToken) => throw new NotSupportedException(); private static SocketError GetSocketErrorForNativeError(int error) @@ -116,7 +117,7 @@ private static unsafe void ParseHostEntry(Interop.Sys.HostEntry hostEntry, bool } } - public static unsafe SocketError TryGetAddrInfo(string name, bool justAddresses, out string? hostName, out string[] aliases, out IPAddress[] addresses, out int nativeErrorCode) + public static unsafe SocketError TryGetAddrInfo(string name, bool justAddresses, AddressFamily addressFamily, out string? hostName, out string[] aliases, out IPAddress[] addresses, out int nativeErrorCode) { if (name == "") { @@ -125,7 +126,7 @@ public static unsafe SocketError TryGetAddrInfo(string name, bool justAddresses, } Interop.Sys.HostEntry entry; - int result = Interop.Sys.GetHostEntryForName(name, &entry); + int result = Interop.Sys.GetHostEntryForName(name, addressFamily, &entry); if (result != 0) { nativeErrorCode = result; diff --git a/src/libraries/System.Net.NameResolution/src/System/Net/NameResolutionPal.Windows.cs b/src/libraries/System.Net.NameResolution/src/System/Net/NameResolutionPal.Windows.cs index 1de610f86e572..3cdf343afa199 100644 --- a/src/libraries/System.Net.NameResolution/src/System/Net/NameResolutionPal.Windows.cs +++ b/src/libraries/System.Net.NameResolution/src/System/Net/NameResolutionPal.Windows.cs @@ -41,13 +41,13 @@ static void Initialize() } } - public static unsafe SocketError TryGetAddrInfo(string name, bool justAddresses, out string? hostName, out string[] aliases, out IPAddress[] addresses, out int nativeErrorCode) + public static unsafe SocketError TryGetAddrInfo(string name, bool justAddresses, AddressFamily addressFamily, out string? hostName, out string[] aliases, out IPAddress[] addresses, out int nativeErrorCode) { Interop.Winsock.EnsureInitialized(); aliases = Array.Empty(); - var hints = new Interop.Winsock.AddressInfo { ai_family = AddressFamily.Unspecified }; // Gets all address families + var hints = new Interop.Winsock.AddressInfo { ai_family = addressFamily }; if (!justAddresses) { hints.ai_flags = AddressInfoHints.AI_CANONNAME; @@ -133,7 +133,7 @@ public static unsafe string GetHostName() return new string((sbyte*)buffer); } - public static unsafe Task GetAddrInfoAsync(string hostName, bool justAddresses) + public static unsafe Task GetAddrInfoAsync(string hostName, bool justAddresses, AddressFamily family, CancellationToken cancellationToken) { Interop.Winsock.EnsureInitialized(); @@ -142,7 +142,7 @@ public static unsafe Task GetAddrInfoAsync(string hostName, bool justAddresses) GetAddrInfoExState state; try { - state = new GetAddrInfoExState(hostName, justAddresses); + state = new GetAddrInfoExState(context, hostName, justAddresses); context->QueryStateHandle = state.CreateHandle(); } catch @@ -151,7 +151,7 @@ public static unsafe Task GetAddrInfoAsync(string hostName, bool justAddresses) throw; } - var hints = new Interop.Winsock.AddressInfoEx { ai_family = AddressFamily.Unspecified }; // Gets all address families + var hints = new Interop.Winsock.AddressInfoEx { ai_family = family }; if (!justAddresses) { hints.ai_flags = AddressInfoHints.AI_CANONNAME; @@ -160,7 +160,11 @@ public static unsafe Task GetAddrInfoAsync(string hostName, bool justAddresses) SocketError errorCode = (SocketError)Interop.Winsock.GetAddrInfoExW( hostName, null, Interop.Winsock.NS_ALL, IntPtr.Zero, &hints, &context->Result, IntPtr.Zero, &context->Overlapped, &GetAddressInfoExCallback, &context->CancelHandle); - if (errorCode != SocketError.IOPending) + if (errorCode == SocketError.IOPending) + { + state.RegisterForCancellation(cancellationToken); + } + else { ProcessResult(errorCode, context); } @@ -183,6 +187,8 @@ private static unsafe void ProcessResult(SocketError errorCode, GetAddrInfoExCon { GetAddrInfoExState state = GetAddrInfoExState.FromHandleAndFree(context->QueryStateHandle); + CancellationToken cancellationToken = state.UnregisterAndGetCancellationToken(); + if (errorCode == SocketError.Success) { IPAddress[] addresses = ParseAddressInfoEx(context->Result, state.JustAddresses, out string? hostName); @@ -197,7 +203,11 @@ private static unsafe void ProcessResult(SocketError errorCode, GetAddrInfoExCon } else { - state.SetResult(ExceptionDispatchInfo.SetCurrentStackTrace(new SocketException((int)errorCode))); + Exception ex = (errorCode == (SocketError)Interop.Winsock.WSA_E_CANCELLED && cancellationToken.IsCancellationRequested) + ? (Exception)new OperationCanceledException(cancellationToken) + : new SocketException((int)errorCode); + + state.SetResult(ExceptionDispatchInfo.SetCurrentStackTrace(ex)); } } finally @@ -340,14 +350,18 @@ private static unsafe IPAddress CreateIPv6Address(ReadOnlySpan socketAddre return new IPAddress(address, scope); } - private sealed class GetAddrInfoExState : IThreadPoolWorkItem + private sealed unsafe class GetAddrInfoExState : IThreadPoolWorkItem { + private GetAddrInfoExContext* _cancellationContext; + private CancellationTokenRegistration _cancellationRegistration; + private AsyncTaskMethodBuilder IPHostEntryBuilder; private AsyncTaskMethodBuilder IPAddressArrayBuilder; private object? _result; - public GetAddrInfoExState(string hostName, bool justAddresses) + public GetAddrInfoExState(GetAddrInfoExContext *context, string hostName, bool justAddresses) { + _cancellationContext = context; HostName = hostName; JustAddresses = justAddresses; if (justAddresses) @@ -368,6 +382,55 @@ public GetAddrInfoExState(string hostName, bool justAddresses) public Task Task => JustAddresses ? (Task)IPAddressArrayBuilder.Task : IPHostEntryBuilder.Task; + public void RegisterForCancellation(CancellationToken cancellationToken) + { + if (!cancellationToken.CanBeCanceled) return; + + lock (this) + { + if (_cancellationContext == null) + { + // The operation completed before registration could be done. + return; + } + + _cancellationRegistration = cancellationToken.UnsafeRegister(o => + { + var @this = (GetAddrInfoExState)o!; + int cancelResult = 0; + + lock (@this) + { + GetAddrInfoExContext* context = @this._cancellationContext; + + if (context != null) + { + // An outstanding operation will be completed with WSA_E_CANCELLED, and GetAddrInfoExCancel will return NO_ERROR. + // If this thread has lost the race between cancellation and completion, this will be a NOP + // with GetAddrInfoExCancel returning WSA_INVALID_HANDLE. + cancelResult = Interop.Winsock.GetAddrInfoExCancel(&context->CancelHandle); + } + } + + if (cancelResult != 0 && cancelResult != Interop.Winsock.WSA_INVALID_HANDLE && NetEventSource.IsEnabled) + { + NetEventSource.Info(@this, $"GetAddrInfoExCancel returned error {cancelResult}"); + } + }, this); + } + } + + public CancellationToken UnregisterAndGetCancellationToken() + { + lock (this) + { + _cancellationContext = null; + _cancellationRegistration.Unregister(); + } + + return _cancellationRegistration.Token; + } + public void SetResult(object result) { // Store the result and then queue this object to the thread pool to actually complete the Tasks, as we diff --git a/src/libraries/System.Net.NameResolution/tests/FunctionalTests/GetHostAddressesTest.cs b/src/libraries/System.Net.NameResolution/tests/FunctionalTests/GetHostAddressesTest.cs index 9e14de7cf0a63..4151b1070a686 100644 --- a/src/libraries/System.Net.NameResolution/tests/FunctionalTests/GetHostAddressesTest.cs +++ b/src/libraries/System.Net.NameResolution/tests/FunctionalTests/GetHostAddressesTest.cs @@ -3,6 +3,7 @@ using System.Collections.Generic; using System.Net.Sockets; +using System.Threading; using System.Threading.Tasks; using Xunit; @@ -109,6 +110,26 @@ public void DnsGetHostAddresses_IPv6String_ReturnsSameIP() Assert.Equal(IPAddress.IPv6Loopback, addresses[0]); } + [Theory] + [MemberData(nameof(IPAndIncorrectFamily_Data))] + public async Task DnsGetHostAddresses_IPStringAndIncorrectFamily_ReturnsNoIPs(bool useAsync, IPAddress address, AddressFamily family) + { + IPAddress[] addresses = + useAsync ? await Dns.GetHostAddressesAsync(address.ToString(), family) : + Dns.GetHostAddresses(address.ToString(), family); + + Assert.Empty(addresses); + } + + public static TheoryData IPAndIncorrectFamily_Data => new TheoryData + { + // useAsync, IP, family + { false, IPAddress.Loopback, AddressFamily.InterNetworkV6 }, + { false, IPAddress.IPv6Loopback, AddressFamily.InterNetwork }, + { true, IPAddress.Loopback, AddressFamily.InterNetworkV6 }, + { true, IPAddress.IPv6Loopback, AddressFamily.InterNetwork } + }; + [Fact] public void DnsGetHostAddresses_LocalHost_ReturnsSameAsGetHostEntry() { @@ -117,5 +138,48 @@ public void DnsGetHostAddresses_LocalHost_ReturnsSameAsGetHostEntry() Assert.Equal(ipEntry.AddressList, addresses); } + + [OuterLoop] + [Theory] + [InlineData(false, TestSettings.IPv4Host, AddressFamily.InterNetwork)] + [InlineData(false, TestSettings.IPv6Host, AddressFamily.InterNetworkV6)] + [InlineData(true, TestSettings.IPv4Host, AddressFamily.InterNetwork)] + [InlineData(true, TestSettings.IPv6Host, AddressFamily.InterNetworkV6)] + public async Task DnsGetHostAddresses_LocalHost_AddressFamilySpecific(bool useAsync, string host, AddressFamily addressFamily) + { + IPAddress[] addresses = + useAsync ? await Dns.GetHostAddressesAsync(host, addressFamily) : + Dns.GetHostAddresses(host, addressFamily); + + Assert.All(addresses, address => Assert.Equal(addressFamily, address.AddressFamily)); + } + + [Fact] + public async Task DnsGetHostAddresses_PreCancelledToken_Throws() + { + using var cts = new CancellationTokenSource(); + cts.Cancel(); + + OperationCanceledException oce = await Assert.ThrowsAnyAsync(() => Dns.GetHostAddressesAsync(TestSettings.LocalHost, cts.Token)); + Assert.Equal(cts.Token, oce.CancellationToken); + } + + [OuterLoop] + [Fact] + [ActiveIssue("https://github.com/dotnet/runtime/issues/43816")] // Race condition outlined below. + [ActiveIssue("https://github.com/dotnet/runtime/issues/33378", TestPlatforms.AnyUnix)] // Cancellation of an outstanding getaddrinfo is not supported on *nix. + public async Task DnsGetHostAddresses_PostCancelledToken_Throws() + { + using var cts = new CancellationTokenSource(); + + Task task = Dns.GetHostAddressesAsync(TestSettings.UncachedHost, cts.Token); + + // This test might flake if the cancellation token takes too long to trigger: + // It's a race between the DNS server getting back to us and the cancellation processing. + cts.Cancel(); + + OperationCanceledException oce = await Assert.ThrowsAnyAsync(() => task); + Assert.Equal(cts.Token, oce.CancellationToken); + } } } diff --git a/src/libraries/System.Net.NameResolution/tests/FunctionalTests/GetHostEntryTest.cs b/src/libraries/System.Net.NameResolution/tests/FunctionalTests/GetHostEntryTest.cs index 4ec2ebe53f15e..65e2ec65d6591 100644 --- a/src/libraries/System.Net.NameResolution/tests/FunctionalTests/GetHostEntryTest.cs +++ b/src/libraries/System.Net.NameResolution/tests/FunctionalTests/GetHostEntryTest.cs @@ -4,6 +4,7 @@ using System.Collections.Generic; using System.IO; using System.Net.Sockets; +using System.Threading; using System.Threading.Tasks; using Xunit; @@ -228,5 +229,48 @@ public async Task DnsGetHostEntry_LoopbackIP_MatchesGetHostEntryLoopbackString(i Assert.Equal(ipEntry.HostName, stringEntry.HostName); Assert.Equal(ipEntry.AddressList, stringEntry.AddressList); } + + [OuterLoop] + [Theory] + [InlineData(false, TestSettings.IPv4Host, AddressFamily.InterNetwork)] + [InlineData(false, TestSettings.IPv6Host, AddressFamily.InterNetworkV6)] + [InlineData(true, TestSettings.IPv4Host, AddressFamily.InterNetwork)] + [InlineData(true, TestSettings.IPv6Host, AddressFamily.InterNetworkV6)] + public async Task DnsGetHostEntry_LocalHost_AddressFamilySpecific(bool useAsync, string host, AddressFamily addressFamily) + { + IPHostEntry entry = + useAsync ? await Dns.GetHostEntryAsync(host, addressFamily) : + Dns.GetHostEntry(host, addressFamily); + + Assert.All(entry.AddressList, address => Assert.Equal(addressFamily, address.AddressFamily)); + } + + [Fact] + public async Task DnsGetHostEntry_PreCancelledToken_Throws() + { + using var cts = new CancellationTokenSource(); + cts.Cancel(); + + OperationCanceledException oce = await Assert.ThrowsAnyAsync(() => Dns.GetHostEntryAsync(TestSettings.LocalHost, cts.Token)); + Assert.Equal(cts.Token, oce.CancellationToken); + } + + [OuterLoop] + [ActiveIssue("https://github.com/dotnet/runtime/issues/43816")] // Race condition outlined below. + [ActiveIssue("https://github.com/dotnet/runtime/issues/33378", TestPlatforms.AnyUnix)] // Cancellation of an outstanding getaddrinfo is not supported on *nix. + [Fact] + public async Task DnsGetHostEntry_PostCancelledToken_Throws() + { + using var cts = new CancellationTokenSource(); + + Task task = Dns.GetHostEntryAsync(TestSettings.UncachedHost, cts.Token); + + // This test might flake if the cancellation token takes too long to trigger: + // It's a race between the DNS server getting back to us and the cancellation processing. + cts.Cancel(); + + OperationCanceledException oce = await Assert.ThrowsAnyAsync(() => task); + Assert.Equal(cts.Token, oce.CancellationToken); + } } } diff --git a/src/libraries/System.Net.NameResolution/tests/FunctionalTests/TestSettings.cs b/src/libraries/System.Net.NameResolution/tests/FunctionalTests/TestSettings.cs index 9f93adf2fccd9..23b452517964a 100644 --- a/src/libraries/System.Net.NameResolution/tests/FunctionalTests/TestSettings.cs +++ b/src/libraries/System.Net.NameResolution/tests/FunctionalTests/TestSettings.cs @@ -13,6 +13,13 @@ public partial class NoParallelTests { } internal static class TestSettings { + // A hostname that will not exist in any DNS caches, forcing some I/O to lookup. + public static string UncachedHost => $"nonexistent-{Guid.NewGuid():N}.contoso.com"; + + public const string IPv4Host = "ipv4.google.com"; + + public const string IPv6Host = "ipv6.google.com"; + public const string LocalHost = "localhost"; public const string LocalIPString = "127.0.0.1"; diff --git a/src/libraries/System.Net.NameResolution/tests/PalTests/NameResolutionPalTests.cs b/src/libraries/System.Net.NameResolution/tests/PalTests/NameResolutionPalTests.cs index 5709b7d058b27..3c5a050a42e52 100644 --- a/src/libraries/System.Net.NameResolution/tests/PalTests/NameResolutionPalTests.cs +++ b/src/libraries/System.Net.NameResolution/tests/PalTests/NameResolutionPalTests.cs @@ -37,7 +37,7 @@ public void HostName_NotNull() [InlineData(true)] public void TryGetAddrInfo_LocalHost(bool justAddresses) { - SocketError error = NameResolutionPal.TryGetAddrInfo("localhost", justAddresses, out string hostName, out string[] aliases, out IPAddress[] addresses, out int nativeErrorCode); + SocketError error = NameResolutionPal.TryGetAddrInfo("localhost", justAddresses, AddressFamily.Unspecified, out string hostName, out string[] aliases, out IPAddress[] addresses, out int nativeErrorCode); Assert.Equal(SocketError.Success, error); if (!justAddresses) { @@ -56,7 +56,7 @@ public void TryGetAddrInfo_HostName(bool justAddresses) string hostName = NameResolutionPal.GetHostName(); Assert.NotNull(hostName); - SocketError error = NameResolutionPal.TryGetAddrInfo(hostName, justAddresses, out hostName, out string[] aliases, out IPAddress[] addresses, out int nativeErrorCode); + SocketError error = NameResolutionPal.TryGetAddrInfo(hostName, justAddresses, AddressFamily.Unspecified, out hostName, out string[] aliases, out IPAddress[] addresses, out int nativeErrorCode); if (error == SocketError.HostNotFound && (OperatingSystem.IsLinux() || OperatingSystem.IsMacOS())) { // On Unix, we are not guaranteed to be able to resove the local host. The ability to do so depends on the @@ -101,7 +101,7 @@ public void TryGetNameInfo_LocalHost_IPv6() [Fact] public void TryGetAddrInfo_LocalHost_TryGetNameInfo() { - SocketError error = NameResolutionPal.TryGetAddrInfo("localhost", justAddresses: false, out string hostName, out string[] aliases, out IPAddress[] addresses, out int nativeErrorCode); + SocketError error = NameResolutionPal.TryGetAddrInfo("localhost", justAddresses: false, AddressFamily.Unspecified, out string hostName, out string[] aliases, out IPAddress[] addresses, out int nativeErrorCode); Assert.Equal(SocketError.Success, error); Assert.NotNull(hostName); Assert.NotNull(aliases); @@ -119,7 +119,7 @@ public void TryGetAddrInfo_HostName_TryGetNameInfo() string hostName = NameResolutionPal.GetHostName(); Assert.NotNull(hostName); - SocketError error = NameResolutionPal.TryGetAddrInfo(hostName, justAddresses: false, out hostName, out string[] aliases, out IPAddress[] addresses, out int nativeErrorCode); + SocketError error = NameResolutionPal.TryGetAddrInfo(hostName, justAddresses: false, AddressFamily.Unspecified, out hostName, out string[] aliases, out IPAddress[] addresses, out int nativeErrorCode); if (error == SocketError.HostNotFound) { // On Unix, getaddrinfo returns host not found, if all the machine discovery settings on the local network @@ -153,7 +153,7 @@ public void TryGetAddrInfo_ExternalHost(bool justAddresses) { string hostName = "microsoft.com"; - SocketError error = NameResolutionPal.TryGetAddrInfo(hostName, justAddresses, out hostName, out string[] aliases, out IPAddress[] addresses, out _); + SocketError error = NameResolutionPal.TryGetAddrInfo(hostName, justAddresses, AddressFamily.Unspecified, out hostName, out string[] aliases, out IPAddress[] addresses, out _); Assert.Equal(SocketError.Success, error); Assert.NotNull(aliases); Assert.NotNull(addresses); @@ -168,7 +168,7 @@ public void TryGetNameInfo_LocalHost_IPv4_TryGetAddrInfo(bool justAddresses) Assert.Equal(SocketError.Success, error); Assert.NotNull(name); - error = NameResolutionPal.TryGetAddrInfo(name, justAddresses, out string hostName, out string[] aliases, out IPAddress[] addresses, out _); + error = NameResolutionPal.TryGetAddrInfo(name, justAddresses, AddressFamily.Unspecified, out string hostName, out string[] aliases, out IPAddress[] addresses, out _); Assert.Equal(SocketError.Success, error); Assert.NotNull(aliases); Assert.NotNull(addresses); @@ -190,7 +190,7 @@ public void TryGetNameInfo_LocalHost_IPv6_TryGetAddrInfo(bool justAddresses) Assert.Equal(SocketError.Success, error); Assert.NotNull(name); - error = NameResolutionPal.TryGetAddrInfo(name, justAddresses, out string hostName, out string[] aliases, out IPAddress[] addresses, out _); + error = NameResolutionPal.TryGetAddrInfo(name, justAddresses, AddressFamily.Unspecified, out string hostName, out string[] aliases, out IPAddress[] addresses, out _); if (SocketError.Success != error && Environment.OSVersion.Platform == PlatformID.Unix) { LogUnixInfo(); diff --git a/src/libraries/System.Net.NameResolution/tests/UnitTests/Fakes/FakeNameResolutionPal.cs b/src/libraries/System.Net.NameResolution/tests/UnitTests/Fakes/FakeNameResolutionPal.cs index 78f031f2fb415..efa509ec0020c 100644 --- a/src/libraries/System.Net.NameResolution/tests/UnitTests/Fakes/FakeNameResolutionPal.cs +++ b/src/libraries/System.Net.NameResolution/tests/UnitTests/Fakes/FakeNameResolutionPal.cs @@ -2,6 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. using System.Net.Sockets; +using System.Threading; using System.Threading.Tasks; namespace System.Net @@ -21,7 +22,7 @@ internal static void FakesReset() FakesGetHostByNameCallCount = 0; } - internal static SocketError TryGetAddrInfo(string name, bool justAddresses, out string hostName, out string[] aliases, out IPAddress[] addresses, out int nativeErrorCode) + internal static SocketError TryGetAddrInfo(string name, bool justAddresses, AddressFamily addressFamily, out string hostName, out string[] aliases, out IPAddress[] addresses, out int nativeErrorCode) { throw new NotImplementedException(); } @@ -37,7 +38,7 @@ internal static string TryGetNameInfo(IPAddress address, out SocketError errorCo throw new NotImplementedException(); } - internal static Task GetAddrInfoAsync(string hostName, bool justAddresses) + internal static Task GetAddrInfoAsync(string hostName, bool justAddresses, AddressFamily addressFamily, CancellationToken cancellationToken) { throw new NotImplementedException(); }