diff --git a/src/main/java/org/openrewrite/staticanalysis/InstanceOfPatternMatch.java b/src/main/java/org/openrewrite/staticanalysis/InstanceOfPatternMatch.java index c8b207e9..6fb2a463 100644 --- a/src/main/java/org/openrewrite/staticanalysis/InstanceOfPatternMatch.java +++ b/src/main/java/org/openrewrite/staticanalysis/InstanceOfPatternMatch.java @@ -153,6 +153,12 @@ private static class ExpressionAndType { private final JavaType type; } + @Data + private static class VariableAndTypeTree { + private final J.VariableDeclarations.NamedVariable variable; + private final TypeTree type; + } + @Data private static class InstanceOfPatternReplacements { private final J root; @@ -160,7 +166,7 @@ private static class InstanceOfPatternReplacements { private final Map> contexts = new HashMap<>(); private final Map> contextScopes = new HashMap<>(); private final Map replacements = new HashMap<>(); - private final Map variablesToDelete = new HashMap<>(); + private final Map variablesToDelete = new HashMap<>(); public void registerInstanceOf(J.InstanceOf instanceOf, Set contexts) { Expression expression = instanceOf.getExpression(); @@ -192,13 +198,21 @@ public void registerTypeCast(J.TypeCast typeCast, Cursor cursor) { for (Iterator it = cursor.getPath(); it.hasNext(); ) { Object next = it.next(); if (validContexts.contains(next)) { - if (parent.getValue() instanceof J.VariableDeclarations.NamedVariable - && !variablesToDelete.containsKey(instanceOf)) { - variablesToDelete.put(instanceOf, parent.getValue()); + if (isAcceptableTypeCast(typeCast) && isTheSameAsOtherTypeCasts(typeCast, instanceOf)) { + if (parent.getValue() instanceof J.VariableDeclarations.NamedVariable + && !variablesToDelete.containsKey(instanceOf)) { + variablesToDelete.put(instanceOf, new VariableAndTypeTree(parent.getValue(), parent.firstEnclosing(J.VariableDeclarations.class).getTypeExpression())); + } else { + replacements.put(typeCast, instanceOf); + } + contextScopes.computeIfAbsent(instanceOf, k -> new HashSet<>()).add(cursor); } else { - replacements.put(typeCast, instanceOf); + replacements.entrySet().removeIf(e -> e.getValue() == instanceOf); + variablesToDelete.remove(instanceOf); + contextScopes.remove(instanceOf); + contexts.remove(instanceOf); + instanceOfs.entrySet().removeIf(e -> e.getValue() == instanceOf); } - contextScopes.computeIfAbsent(instanceOf, k -> new HashSet<>()).add(cursor); break; } else if (root == next) { break; @@ -207,6 +221,24 @@ public void registerTypeCast(J.TypeCast typeCast, Cursor cursor) { } } + private boolean isAcceptableTypeCast(J.TypeCast typeCast) { + TypeTree typeTree = typeCast.getClazz().getTree(); + if (typeTree instanceof J.ParameterizedType) { + return ((J.ParameterizedType) typeTree).getTypeParameters().stream().allMatch(J.Wildcard.class::isInstance); + } + return true; + } + + private boolean isTheSameAsOtherTypeCasts(J.TypeCast typeCast, J.InstanceOf instanceOf) { + return replacements + .entrySet() + .stream() + .filter(e -> e.getValue() == instanceOf) + .findFirst() + .map(e -> e.getKey().getType().equals(typeCast.getType())) + .orElse(true); + } + public boolean isEmpty() { return replacements.isEmpty() && variablesToDelete.isEmpty(); } @@ -225,24 +257,14 @@ public J.InstanceOf processInstanceOf(J.InstanceOf instanceOf, Cursor cursor) { name, type, null)); - JavaType.FullyQualified fqType = TypeUtils.asFullyQualified(type); - if (fqType != null && !fqType.getTypeParameters().isEmpty() && !(instanceOf.getClazz() instanceof J.ParameterizedType)) { - TypedTree oldTypeTree = (TypedTree) instanceOf.getClazz(); - // Each type parameter is turned into a wildcard, i.e. `List` -> `List` or `Map.Entry` -> `Map.Entry` - List wildcardsList = IntStream.range(0, fqType.getTypeParameters().size()) - .mapToObj(i -> new J.Wildcard(randomId(), Space.EMPTY, Markers.EMPTY, null, null)) - .collect(Collectors.toList()); - J.ParameterizedType newTypeTree = new J.ParameterizedType( - randomId(), - oldTypeTree.getPrefix(), - Markers.EMPTY, - oldTypeTree.withPrefix(Space.EMPTY), - null, - oldTypeTree.getType() - ).withTypeParameters(wildcardsList); - result = result.withClazz(newTypeTree); + J currentTypeTree = instanceOf.getClazz(); + TypeTree typeCastTypeTree = computeTypeTreeFromTypeCasts(instanceOf); + // If type tree from typa cast is not parameterized then NVM. Instance of should already have proper type + if (typeCastTypeTree != null && typeCastTypeTree instanceof J.ParameterizedType) { + J.ParameterizedType parameterizedType = (J.ParameterizedType) typeCastTypeTree; + result = result.withClazz(parameterizedType.withId(Tree.randomId()).withPrefix(currentTypeTree.getPrefix())); } // update entry in replacements to share the pattern variable name @@ -254,12 +276,29 @@ public J.InstanceOf processInstanceOf(J.InstanceOf instanceOf, Cursor cursor) { return result; } + private TypeTree computeTypeTreeFromTypeCasts(J.InstanceOf instanceOf) { + TypeTree typeCastTypeTree = replacements + .entrySet() + .stream() + .filter(e -> e.getValue() == instanceOf) + .findFirst() + .map(e -> e.getKey().getClazz().getTree()) + .orElse(null); + if (typeCastTypeTree == null) { + VariableAndTypeTree variable = variablesToDelete.get(instanceOf); + if (variable != null) { + typeCastTypeTree = variable.getType(); + } + } + return typeCastTypeTree; + } + private String patternVariableName(J.InstanceOf instanceOf, Cursor cursor) { VariableNameStrategy strategy; if (root instanceof J.If) { - J.VariableDeclarations.NamedVariable variable = variablesToDelete.get(instanceOf); - strategy = variable != null - ? VariableNameStrategy.exact(variable.getSimpleName()) + VariableAndTypeTree variableData = variablesToDelete.get(instanceOf); + strategy = variableData != null + ? VariableNameStrategy.exact(variableData.getVariable().getSimpleName()) : VariableNameStrategy.normal(contextScopes.get(instanceOf)); } else { strategy = VariableNameStrategy.short_(); @@ -288,7 +327,7 @@ private String patternVariableName(J.InstanceOf instanceOf, Cursor cursor) { } public @Nullable J processVariableDeclarations(J.VariableDeclarations multiVariable) { - return multiVariable.getVariables().stream().anyMatch(variablesToDelete::containsValue) ? null : multiVariable; + return multiVariable.getVariables().stream().anyMatch(v -> variablesToDelete.values().stream().anyMatch(vd -> vd.getVariable() == v)) ? null : multiVariable; } } diff --git a/src/test/java/org/openrewrite/staticanalysis/InstanceOfPatternMatchTest.java b/src/test/java/org/openrewrite/staticanalysis/InstanceOfPatternMatchTest.java index 8279511d..cfcbc9d0 100644 --- a/src/test/java/org/openrewrite/staticanalysis/InstanceOfPatternMatchTest.java +++ b/src/test/java/org/openrewrite/staticanalysis/InstanceOfPatternMatchTest.java @@ -133,7 +133,7 @@ void test(Object o) { } @Test - void genericsWithoutParameters() { + void typeParameters_1() { rewriteRun( //language=java java( @@ -142,21 +142,48 @@ void genericsWithoutParameters() { import java.util.List; import java.util.Map; import java.util.stream.Collectors; + import java.util.stream.Stream; public class A { @SuppressWarnings("unchecked") - public static List> applyRoutesType(Object routes) { + public static Stream> applyRoutesType(Object routes) { if (routes instanceof List) { List routesList = (List) routes; if (routesList.isEmpty()) { - return Collections.emptyList(); + return Stream.empty(); } if (routesList.stream() .anyMatch(route -> !(route instanceof Map))) { - return Collections.emptyList(); + return Stream.empty(); } return routesList.stream() - .map(route -> (Map) route) - .collect(Collectors.toList()); + .map(route -> (Map) route); + } + return Stream.empty(); + } + } + """ + ) + ); + } + + @Test + void typeParameters_2() { + rewriteRun( + //language=java + java( + """ + import java.util.Collections; + import java.util.List; + import java.util.Map; + import java.util.stream.Collectors; + public class A { + @SuppressWarnings("unchecked") + public static List> applyRoutesType(Object routes) { + if (routes instanceof List) { + List routesList = (List) routes; + if (routesList.isEmpty()) { + return Collections.emptyList(); + } } return Collections.emptyList(); } @@ -170,7 +197,111 @@ public static List> applyRoutesType(Object routes) { public class A { @SuppressWarnings("unchecked") public static List> applyRoutesType(Object routes) { - if (routes instanceof List routesList) { + if (routes instanceof List routesList) { + if (routesList.isEmpty()) { + return Collections.emptyList(); + } + } + return Collections.emptyList(); + } + } + """ + ) + ); + } + + @Test + void typeParameters_3() { + rewriteRun( + //language=java + java( + """ + import java.util.Collections; + import java.util.List; + public class A { + @SuppressWarnings("unchecked") + public static void applyRoutesType(Object routes) { + if (routes instanceof List) { + List routesList = (List) routes; + String.join(",", (List) routes); + } + } + } + """ + ) + ); + } + + @Test + void typeParameters_4() { + rewriteRun( + //language=java + java( + """ + import java.util.Collections; + import java.util.List; + public class A { + @SuppressWarnings("unchecked") + public static void applyRoutesType(Object routes) { + if (routes instanceof List) { + String.join(",", (List) routes); + } + } + } + """, """ + import java.util.Collections; + import java.util.List; + public class A { + @SuppressWarnings("unchecked") + public static void applyRoutesType(Object routes) { + if (routes instanceof List list) { + String.join(",", list); + } + } + } + """ + ) + ); + } + + @Test + void typeParameters_5() { + rewriteRun( + //language=java + java( + """ + import java.util.Arrays; + import java.util.Collection; + import java.util.List; + public class A { + @SuppressWarnings("unchecked") + private Collection addValueToList(List previousValues, Object value) { + if (previousValues == null) { + return (value instanceof Collection) ? (Collection) value : Arrays.asList(value); + } + return List.of(); + } + } + """ + ) + ); + } + + @Test + void typeParameters_6() { + rewriteRun( + //language=java + java( + """ + import java.util.Collections; + import java.util.List; + import java.util.Map; + import java.util.stream.Collectors; + public class A { + @SuppressWarnings("unchecked") + public static List> applyRoutesType(Object routes) { + if (routes instanceof List) { + List routesList = (List) routes; if (routesList.isEmpty()) { return Collections.emptyList(); } @@ -190,6 +321,76 @@ public static List> applyRoutesType(Object routes) { ); } + @Test + void typeParameters_7() { + rewriteRun( + //language=java + java( + """ + import java.util.Collections; + import java.util.List; + import java.util.Map; + import java.util.stream.Collectors; + public class A { + @SuppressWarnings("unchecked") + public static List> applyRoutesType(Object routes) { + if (routes instanceof List) { + return ((List) routes).stream() + .map(route -> (Map) route) + .collect(Collectors.toList()); + } + return Collections.emptyList(); + } + } + """, """ + import java.util.Collections; + import java.util.List; + import java.util.Map; + import java.util.stream.Collectors; + public class A { + @SuppressWarnings("unchecked") + public static List> applyRoutesType(Object routes) { + if (routes instanceof List list) { + return list.stream() + .map(route -> (Map) route) + .collect(Collectors.toList()); + } + return Collections.emptyList(); + } + } + """ + ) + ); + } + + @Test + void typeParameters_8() { + rewriteRun( + //language=java + java( + """ + import java.util.Arrays; + import java.util.Collection; + import java.util.List; + public class A { + @SuppressWarnings("unchecked") + private Collection addValueToList(List previousValues, Object value) { + Collection cl = List.of(); + if (previousValues == null) { + if (value instanceof Collection) { + cl = (Collection) value; + } else { + cl = Arrays.asList(value.toString()); + } + } + return cl; + } + } + """ + ) + ); + } + @Test void primitiveArray() { rewriteRun( @@ -256,7 +457,7 @@ void conflictingVariableInBody() { public class A { void test(Object o) { if (o instanceof String) { - String string = 'x'; + String string = "x"; System.out.println((String) o); // String string1 = "y"; } @@ -267,7 +468,7 @@ void test(Object o) { public class A { void test(Object o) { if (o instanceof String string1) { - String string = 'x'; + String string = "x"; System.out.println(string1); // String string1 = "y"; } @@ -302,7 +503,7 @@ void test(Object o) { public class A { void test(Object o) { Map.Entry entry = null; - if (o instanceof Map.Entry entry1) { + if (o instanceof Map.Entry entry1) { entry = entry1; } System.out.println(entry); @@ -869,7 +1070,7 @@ Object test(Object o) { return o instanceof List ? ((List) o).get(0) : o.toString(); } } - """, + """/*, """ import java.util.List; public class A { @@ -877,7 +1078,7 @@ Object test(Object o) { return o instanceof List l ? l.get(0) : o.toString(); } } - """ + """*/ ) ); } @@ -975,6 +1176,52 @@ String test(Object o) { ) ); } + @Test + void iterableParameter() { + rewriteRun( + //language=java + java( + """ + import java.util.HashMap; + import java.util.List; + import java.util.Map; + + public class ApplicationSecurityGroupsParameterHelper { + + static final String APPLICATION_SECURITY_GROUPS = "application-security-groups"; + + public Map transformGatewayParameters(Map parameters) { + Map environment = new HashMap<>(); + Object applicationSecurityGroups = parameters.get(APPLICATION_SECURITY_GROUPS); + if (applicationSecurityGroups instanceof List) { + environment.put(APPLICATION_SECURITY_GROUPS, String.join(",", (List) applicationSecurityGroups)); + } + return environment; + } + } + """, + """ + import java.util.HashMap; + import java.util.List; + import java.util.Map; + + public class ApplicationSecurityGroupsParameterHelper { + + static final String APPLICATION_SECURITY_GROUPS = "application-security-groups"; + + public Map transformGatewayParameters(Map parameters) { + Map environment = new HashMap<>(); + Object applicationSecurityGroups = parameters.get(APPLICATION_SECURITY_GROUPS); + if (applicationSecurityGroups instanceof List list) { + environment.put(APPLICATION_SECURITY_GROUPS, String.join(",", list)); + } + return environment; + } + } + """ + ) + ); + } } @Nested