diff --git a/src/main/java/org/openrewrite/staticanalysis/EqualsAvoidsNullVisitor.java b/src/main/java/org/openrewrite/staticanalysis/EqualsAvoidsNullVisitor.java index 420d4eacb..e525c9d74 100644 --- a/src/main/java/org/openrewrite/staticanalysis/EqualsAvoidsNullVisitor.java +++ b/src/main/java/org/openrewrite/staticanalysis/EqualsAvoidsNullVisitor.java @@ -17,7 +17,8 @@ import lombok.EqualsAndHashCode; import lombok.Value; -import org.jspecify.annotations.Nullable; +import lombok.val; +import org.jetbrains.annotations.NotNull; import org.openrewrite.Tree; import org.openrewrite.java.JavaVisitor; import org.openrewrite.java.MethodMatcher; @@ -26,54 +27,65 @@ import org.openrewrite.marker.Markers; import static java.util.Collections.singletonList; +import static java.util.Objects.requireNonNull; @Value @EqualsAndHashCode(callSuper = false) public class EqualsAvoidsNullVisitor
extends JavaVisitor
{ - private static final MethodMatcher STRING_EQUALS = new MethodMatcher("String equals(java.lang.Object)"); - private static final MethodMatcher STRING_EQUALS_IGNORE_CASE = new MethodMatcher("String equalsIgnoreCase(java.lang.String)"); + + private static final MethodMatcher EQUALS = new MethodMatcher("java.lang.String equals(java.lang.Object)"); + private static final MethodMatcher EQUALS_IGNORE_CASE = new MethodMatcher("java.lang.String equalsIgnoreCase(java" + + ".lang.String)"); + private static final MethodMatcher COMPARE_TO = new MethodMatcher("java.lang.String compareTo(java.lang.String)"); + private static final MethodMatcher COMPARE_TO_IGNORE_CASE = new MethodMatcher("java.lang.String " + + "compareToIgnoreCase(java.lang.String)"); + private static final MethodMatcher CONTENT_EQUALS = new MethodMatcher("java.lang.String contentEquals(java.lang" + + ".CharSequence)"); EqualsAvoidsNullStyle style; @Override public J visitMethodInvocation(J.MethodInvocation method, P p) { - J j = super.visitMethodInvocation(method, p); - if (!(j instanceof J.MethodInvocation)) { - return j; - } - J.MethodInvocation m = (J.MethodInvocation) j; - if (m.getSelect() == null) { - return m; + val superVisitMethodInvocation = super.visitMethodInvocation(method, p); + if (superVisitMethodInvocation instanceof J.MethodInvocation methodInvocation) { + if (methodInvocation.getSelect() == null) { + return methodInvocation; + } else if (!(methodInvocation.getSelect() instanceof J.Literal) + && methodInvocation.getArguments().get(0) instanceof J.Literal + && EQUALS.matches(methodInvocation) + || !style.getIgnoreEqualsIgnoreCase() + && EQUALS_IGNORE_CASE.matches(methodInvocation) + || COMPARE_TO.matches(methodInvocation) + || COMPARE_TO_IGNORE_CASE.matches(methodInvocation) + || CONTENT_EQUALS.matches(methodInvocation)) { + return visitMethodInvocation(methodInvocation); + } + return methodInvocation; } + return superVisitMethodInvocation; + } - if ((STRING_EQUALS.matches(m) || (!Boolean.TRUE.equals(style.getIgnoreEqualsIgnoreCase()) && STRING_EQUALS_IGNORE_CASE.matches(m))) && - m.getArguments().get(0) instanceof J.Literal && - !(m.getSelect() instanceof J.Literal)) { - Tree parent = getCursor().getParentTreeCursor().getValue(); - if (parent instanceof J.Binary) { - J.Binary binary = (J.Binary) parent; - if (binary.getOperator() == J.Binary.Type.And && binary.getLeft() instanceof J.Binary) { - J.Binary potentialNullCheck = (J.Binary) binary.getLeft(); - if ((isNullLiteral(potentialNullCheck.getLeft()) && matchesSelect(potentialNullCheck.getRight(), m.getSelect())) || - (isNullLiteral(potentialNullCheck.getRight()) && matchesSelect(potentialNullCheck.getLeft(), m.getSelect()))) { - doAfterVisit(new RemoveUnnecessaryNullCheck<>(binary)); - } + private @NotNull Expression visitMethodInvocation(final J.MethodInvocation m) { + val parent = getCursor().getParentTreeCursor().getValue(); + if (parent instanceof final J.Binary binary) { + if (binary.getOperator() == J.Binary.Type.And && binary.getLeft() instanceof final J.Binary left) { + if (isNullLiteral(left.getLeft()) + && matchesSelect(left.getRight(), requireNonNull(m.getSelect())) + || (isNullLiteral(left.getRight()) + && matchesSelect(left.getLeft(), + requireNonNull(m.getSelect())))) { + doAfterVisit(new RemoveUnnecessaryNullCheck<>(binary)); } } - - if (m.getArguments().get(0).getType() == JavaType.Primitive.Null) { - return new J.Binary(Tree.randomId(), m.getPrefix(), Markers.EMPTY, - m.getSelect(), - JLeftPadded.build(J.Binary.Type.Equal).withBefore(Space.SINGLE_SPACE), - m.getArguments().get(0).withPrefix(Space.SINGLE_SPACE), - JavaType.Primitive.Boolean); - } else { - m = m.withSelect(((J.Literal) m.getArguments().get(0)).withPrefix(m.getSelect().getPrefix())) - .withArguments(singletonList(m.getSelect().withPrefix(Space.EMPTY))); - } + } else if (m.getArguments().get(0).getType() == JavaType.Primitive.Null) { + return new J.Binary(Tree.randomId(), m.getPrefix(), Markers.EMPTY, + requireNonNull(m.getSelect()), + JLeftPadded.build(J.Binary.Type.Equal).withBefore(Space.SINGLE_SPACE), + m.getArguments().get(0).withPrefix(Space.SINGLE_SPACE), + JavaType.Primitive.Boolean); } - - return m; + return m.withSelect(m.getArguments().get(0).withPrefix(requireNonNull(m.getSelect()).getPrefix())) + .withArguments(singletonList(m.getSelect().withPrefix(Space.EMPTY))); } private boolean isNullLiteral(Expression expression) { @@ -88,14 +100,6 @@ private static class RemoveUnnecessaryNullCheck
extends JavaVisitor
{ private final J.Binary scope; boolean done; - @Override - public @Nullable J visit(@Nullable Tree tree, P p) { - if (done) { - return (J) tree; - } - return super.visit(tree, p); - } - public RemoveUnnecessaryNullCheck(J.Binary scope) { this.scope = scope; } @@ -106,7 +110,6 @@ public J visitBinary(J.Binary binary, P p) { done = true; return binary.getRight().withPrefix(Space.EMPTY); } - return super.visitBinary(binary, p); } } diff --git a/src/test/java/org/openrewrite/staticanalysis/EqualsAvoidsNullTest.java b/src/test/java/org/openrewrite/staticanalysis/EqualsAvoidsNullVisitorTest.java similarity index 86% rename from src/test/java/org/openrewrite/staticanalysis/EqualsAvoidsNullTest.java rename to src/test/java/org/openrewrite/staticanalysis/EqualsAvoidsNullVisitorTest.java index 68bb55236..620e296d6 100644 --- a/src/test/java/org/openrewrite/staticanalysis/EqualsAvoidsNullTest.java +++ b/src/test/java/org/openrewrite/staticanalysis/EqualsAvoidsNullVisitorTest.java @@ -42,6 +42,9 @@ public class A { String s = null; if(s.equals("test")) {} if(s.equalsIgnoreCase("test")) {} + System.out.println(s.compareTo("test")); + System.out.println(s.compareToIgnoreCase("test")); + System.out.println(s.contentEquals("test")); } } """, @@ -51,6 +54,9 @@ public class A { String s = null; if("test".equals(s)) {} if("test".equalsIgnoreCase(s)) {} + System.out.println("test".compareTo(s)); + System.out.println("test".compareToIgnoreCase(s)); + System.out.println("test".contentEquals(s)); } } """ @@ -88,8 +94,8 @@ public class A { @Test void nullLiteral() { rewriteRun( - //language=java - java(""" + //language=java + java(""" public class A { void foo(String s) { if(s.equals(null)) { @@ -97,8 +103,8 @@ void foo(String s) { } } """, - """ - + """ + public class A { void foo(String s) { if(s == null) { @@ -108,4 +114,5 @@ void foo(String s) { """) ); } + }