From 21c7bfa2001f88a7ac49c80d84b29ffc7cd1592d Mon Sep 17 00:00:00 2001 From: James Newton-King <james@newtonking.com> Date: Thu, 1 Sep 2022 01:02:54 +0800 Subject: [PATCH] [release/7.0] Improve port search retry logic (#43557) * [release/7.0] Improve port search retry logic * Update * Improvements * Comment * Logging * PR feedback --- .../Kestrel/shared/test/ServerRetryHelper.cs | 71 ++++++++++++------- 1 file changed, 46 insertions(+), 25 deletions(-) diff --git a/src/Servers/Kestrel/shared/test/ServerRetryHelper.cs b/src/Servers/Kestrel/shared/test/ServerRetryHelper.cs index d211dde1fd4..d2491c75237 100644 --- a/src/Servers/Kestrel/shared/test/ServerRetryHelper.cs +++ b/src/Servers/Kestrel/shared/test/ServerRetryHelper.cs @@ -2,34 +2,39 @@ // The .NET Foundation licenses this file to you under the MIT license. using System.Net; -using System.Net.Sockets; +using System.Net.NetworkInformation; using Microsoft.Extensions.Logging; namespace Microsoft.AspNetCore.Testing; public static class ServerRetryHelper { - private const int RetryCount = 10; + private const int RetryCount = 20; /// <summary> /// Retry a func. Useful when a test needs an explicit port and you want to avoid port conflicts. /// </summary> public static async Task BindPortsWithRetry(Func<int, Task> retryFunc, ILogger logger) { - var ports = GetFreePorts(RetryCount); - var retryCount = 0; + + // Add a random number to starting port to reduce chance of conflicts because of multiple tests using this retry. + var nextPortAttempt = 5000 + Random.Shared.Next(500); + while (true) { + // Find a port that's available for TCP and UDP. Start with port 5000 and search upwards from there. + var port = GetAvailablePort(nextPortAttempt, logger); try { - await retryFunc(ports[retryCount]); + await retryFunc(port); break; } catch (Exception ex) { retryCount++; + nextPortAttempt = port + 1; if (retryCount >= RetryCount) { @@ -43,34 +48,50 @@ public static class ServerRetryHelper } } - private static int[] GetFreePorts(int count) + private static int GetAvailablePort(int startingPort, ILogger logger) { - var sockets = new List<Socket>(); + logger.LogInformation($"Searching for free port starting at {startingPort}."); - for (var i = 0; i < count; i++) - { - // Find a port that's free by binding port 0. - // Note that this port should be free when the test runs, but: - // - Something else could steal it before the test uses it. - // - UDP port with the same number could be in use. - // For that reason, some retries should be available. - var ipEndPoint = new IPEndPoint(IPAddress.Loopback, 0); - var listenSocket = new Socket(ipEndPoint.AddressFamily, SocketType.Stream, ProtocolType.Tcp); + var unavailableEndpoints = new List<IPEndPoint>(); - listenSocket.Bind(ipEndPoint); + var properties = IPGlobalProperties.GetIPGlobalProperties(); - sockets.Add(listenSocket); - } + // Ignore active connections + AddEndpoints(startingPort, unavailableEndpoints, properties.GetActiveTcpConnections().Select(c => c.LocalEndPoint)); - // Ports are calculated upfront. Rebinding with port 0 could result the same port - // being returned for each retry. - var ports = sockets.Select(s => (IPEndPoint)s.LocalEndPoint).Select(ep => ep.Port).ToArray(); + // Ignore active tcp listners + AddEndpoints(startingPort, unavailableEndpoints, properties.GetActiveTcpListeners()); - foreach (var socket in sockets) + // Ignore active UDP listeners + AddEndpoints(startingPort, unavailableEndpoints, properties.GetActiveUdpListeners()); + + logger.LogInformation($"Found {unavailableEndpoints.Count} unavailable endpoints."); + + for (var i = startingPort; i < ushort.MaxValue; i++) { - socket.Dispose(); + var match = unavailableEndpoints.FirstOrDefault(ep => ep.Port == i); + if (match == null) + { + logger.LogInformation($"Port {i} free."); + return i; + } + else + { + logger.LogInformation($"Port {i} in use. End point: {match}"); + } } - return ports; + throw new Exception($"Couldn't find a free port after {startingPort}."); + + static void AddEndpoints(int startingPort, List<IPEndPoint> endpoints, IEnumerable<IPEndPoint> activeEndpoints) + { + foreach (IPEndPoint endpoint in activeEndpoints) + { + if (endpoint.Port >= startingPort) + { + endpoints.Add(endpoint); + } + } + } } } -- GitLab