diff --git a/junit-platform-commons/src/main/java/org/junit/platform/commons/util/ReflectionUtils.java b/junit-platform-commons/src/main/java/org/junit/platform/commons/util/ReflectionUtils.java index ea824ebd23a9..478b3372add0 100644 --- a/junit-platform-commons/src/main/java/org/junit/platform/commons/util/ReflectionUtils.java +++ b/junit-platform-commons/src/main/java/org/junit/platform/commons/util/ReflectionUtils.java @@ -1407,11 +1407,11 @@ private static Optional findMethod(Class clazz, Predicate pre for (Class current = clazz; isSearchable(current); current = current.getSuperclass()) { // Search for match in current type - List methods = current.isInterface() ? getMethods(current) : getDeclaredMethods(current, BOTTOM_UP); - for (Method method : methods) { - if (predicate.test(method)) { - return Optional.of(method); - } + List methods = current.isInterface() ? getMethods(current, predicate) + : getDeclaredMethods(current, predicate, BOTTOM_UP); + if (!methods.isEmpty()) { + // Since the predicate has already been applied, return the first match. + return Optional.of(methods.get(0)); } // Search for match in interfaces implemented by current type @@ -1506,8 +1506,8 @@ private static List findAllMethodsInHierarchy(Class clazz, Predicate< Preconditions.notNull(traversalMode, "HierarchyTraversalMode must not be null"); // @formatter:off - List localMethods = getDeclaredMethods(clazz, traversalMode).stream() - .filter(predicate.and(method -> !method.isSynthetic())) + List localMethods = getDeclaredMethods(clazz, predicate, traversalMode).stream() + .filter(method -> !method.isSynthetic()) .collect(toList()); List superclassMethods = getSuperclassMethods(clazz, predicate, traversalMode).stream() .filter(method -> !isMethodShadowedByLocalMethods(method, localMethods)) @@ -1533,7 +1533,6 @@ private static List findAllMethodsInHierarchy(Class clazz, Predicate< /** * Custom alternative to {@link Class#getFields()} that sorts the fields * which match the supplied predicate and converts them to a mutable list. - * @param predicate the field filter; never {@code null} */ private static List getFields(Class clazz, Predicate predicate) { return toSortedMutableList(clazz.getFields(), predicate); @@ -1542,7 +1541,6 @@ private static List getFields(Class clazz, Predicate predicate) /** * Custom alternative to {@link Class#getDeclaredFields()} that sorts the * fields which match the supplied predicate and converts them to a mutable list. - * @param predicate the field filter; never {@code null} */ private static List getDeclaredFields(Class clazz, Predicate predicate) { return toSortedMutableList(clazz.getDeclaredFields(), predicate); @@ -1550,24 +1548,26 @@ private static List getDeclaredFields(Class clazz, Predicate pr /** * Custom alternative to {@link Class#getMethods()} that sorts the methods - * and converts them to a mutable list. + * which match the supplied predicate and converts them to a mutable list. */ - private static List getMethods(Class clazz) { - return toSortedMutableList(clazz.getMethods()); + private static List getMethods(Class clazz, Predicate predicate) { + return toSortedMutableList(clazz.getMethods(), predicate); } /** * Custom alternative to {@link Class#getDeclaredMethods()} that sorts the - * methods and converts them to a mutable list. + * methods which match the supplied predicate and converts them to a mutable list. * *

In addition, the list returned by this method includes interface * default methods which are either prepended or appended to the list of * declared methods depending on the supplied traversal mode. */ - private static List getDeclaredMethods(Class clazz, HierarchyTraversalMode traversalMode) { + private static List getDeclaredMethods(Class clazz, Predicate predicate, + HierarchyTraversalMode traversalMode) { + // Note: getDefaultMethods() already sorts the methods, - List defaultMethods = getDefaultMethods(clazz); - List declaredMethods = toSortedMutableList(clazz.getDeclaredMethods()); + List defaultMethods = getDefaultMethods(clazz, predicate); + List declaredMethods = toSortedMutableList(clazz.getDeclaredMethods(), predicate); // Take the traversal mode into account in order to retain the inherited // nature of interface default methods. @@ -1584,23 +1584,23 @@ private static List getDeclaredMethods(Class clazz, HierarchyTraversa /** * Get a sorted, mutable list of all default methods present in interfaces * implemented by the supplied class which are also visible within - * the supplied class. + * the supplied class and match the supplied predicate. * * @see Method Visibility * in the Java Language Specification */ - private static List getDefaultMethods(Class clazz) { + private static List getDefaultMethods(Class clazz, Predicate predicate) { // @formatter:off // Visible default methods are interface default methods that have not // been overridden. List visibleDefaultMethods = Arrays.stream(clazz.getMethods()) - .filter(Method::isDefault) + .filter(predicate.and(Method::isDefault)) .collect(toCollection(ArrayList::new)); if (visibleDefaultMethods.isEmpty()) { return visibleDefaultMethods; } return Arrays.stream(clazz.getInterfaces()) - .map(ReflectionUtils::getMethods) + .map(ifc -> getMethods(ifc, predicate)) .flatMap(List::stream) .filter(visibleDefaultMethods::contains) .collect(toCollection(ArrayList::new)); @@ -1617,9 +1617,10 @@ private static List toSortedMutableList(Field[] fields, Predicate // @formatter:on } - private static List toSortedMutableList(Method[] methods) { + private static List toSortedMutableList(Method[] methods, Predicate predicate) { // @formatter:off return Arrays.stream(methods) + .filter(predicate) .sorted(ReflectionUtils::defaultMethodSorter) // Use toCollection() instead of toList() to ensure list is mutable. .collect(toCollection(ArrayList::new)); @@ -1658,8 +1659,8 @@ private static List getInterfaceMethods(Class clazz, Predicate ifc : clazz.getInterfaces()) { // @formatter:off - List localInterfaceMethods = getMethods(ifc).stream() - .filter(predicate.and(method -> !isAbstract(method))) + List localInterfaceMethods = getMethods(ifc, predicate).stream() + .filter(method -> !isAbstract(method)) .collect(toList()); List superinterfaceMethods = getInterfaceMethods(ifc, predicate, traversalMode).stream()