Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PMD: LiteralsFirstInComparisons for compareTo* and contentEquals #365

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@

import lombok.EqualsAndHashCode;
import lombok.Value;
import org.jspecify.annotations.Nullable;
import lombok.val;
import org.jetbrains.annotations.NotNull;
import org.openrewrite.Tree;
import org.openrewrite.java.JavaVisitor;
import org.openrewrite.java.MethodMatcher;
Expand All @@ -26,54 +27,65 @@
import org.openrewrite.marker.Markers;

import static java.util.Collections.singletonList;
import static java.util.Objects.requireNonNull;

@Value
@EqualsAndHashCode(callSuper = false)
public class EqualsAvoidsNullVisitor<P> extends JavaVisitor<P> {
private static final MethodMatcher STRING_EQUALS = new MethodMatcher("String equals(java.lang.Object)");
private static final MethodMatcher STRING_EQUALS_IGNORE_CASE = new MethodMatcher("String equalsIgnoreCase(java.lang.String)");

private static final MethodMatcher EQUALS = new MethodMatcher("java.lang.String equals(java.lang.Object)");
private static final MethodMatcher EQUALS_IGNORE_CASE = new MethodMatcher("java.lang.String equalsIgnoreCase(java" +
".lang.String)");
private static final MethodMatcher COMPARE_TO = new MethodMatcher("java.lang.String compareTo(java.lang.String)");
private static final MethodMatcher COMPARE_TO_IGNORE_CASE = new MethodMatcher("java.lang.String " +
"compareToIgnoreCase(java.lang.String)");
private static final MethodMatcher CONTENT_EQUALS = new MethodMatcher("java.lang.String contentEquals(java.lang" +
".CharSequence)");

EqualsAvoidsNullStyle style;

@Override
public J visitMethodInvocation(J.MethodInvocation method, P p) {
J j = super.visitMethodInvocation(method, p);
if (!(j instanceof J.MethodInvocation)) {
return j;
}
J.MethodInvocation m = (J.MethodInvocation) j;
if (m.getSelect() == null) {
return m;
val superVisitMethodInvocation = super.visitMethodInvocation(method, p);
if (superVisitMethodInvocation instanceof J.MethodInvocation methodInvocation) {
punkratz312 marked this conversation as resolved.
Show resolved Hide resolved
if (methodInvocation.getSelect() == null) {
return methodInvocation;
} else if (!(methodInvocation.getSelect() instanceof J.Literal)
&& methodInvocation.getArguments().get(0) instanceof J.Literal
&& EQUALS.matches(methodInvocation)
|| !style.getIgnoreEqualsIgnoreCase()
&& EQUALS_IGNORE_CASE.matches(methodInvocation)
|| COMPARE_TO.matches(methodInvocation)
|| COMPARE_TO_IGNORE_CASE.matches(methodInvocation)
|| CONTENT_EQUALS.matches(methodInvocation)) {
return visitMethodInvocation(methodInvocation);
}
return methodInvocation;
}
return superVisitMethodInvocation;
}

if ((STRING_EQUALS.matches(m) || (!Boolean.TRUE.equals(style.getIgnoreEqualsIgnoreCase()) && STRING_EQUALS_IGNORE_CASE.matches(m))) &&
m.getArguments().get(0) instanceof J.Literal &&
!(m.getSelect() instanceof J.Literal)) {
Tree parent = getCursor().getParentTreeCursor().getValue();
if (parent instanceof J.Binary) {
J.Binary binary = (J.Binary) parent;
if (binary.getOperator() == J.Binary.Type.And && binary.getLeft() instanceof J.Binary) {
J.Binary potentialNullCheck = (J.Binary) binary.getLeft();
if ((isNullLiteral(potentialNullCheck.getLeft()) && matchesSelect(potentialNullCheck.getRight(), m.getSelect())) ||
(isNullLiteral(potentialNullCheck.getRight()) && matchesSelect(potentialNullCheck.getLeft(), m.getSelect()))) {
doAfterVisit(new RemoveUnnecessaryNullCheck<>(binary));
}
private @NotNull Expression visitMethodInvocation(final J.MethodInvocation m) {
val parent = getCursor().getParentTreeCursor().getValue();
if (parent instanceof final J.Binary binary) {
if (binary.getOperator() == J.Binary.Type.And && binary.getLeft() instanceof final J.Binary left) {
if (isNullLiteral(left.getLeft())
&& matchesSelect(left.getRight(), requireNonNull(m.getSelect()))
|| (isNullLiteral(left.getRight())
&& matchesSelect(left.getLeft(),
requireNonNull(m.getSelect())))) {
doAfterVisit(new RemoveUnnecessaryNullCheck<>(binary));
}
}

if (m.getArguments().get(0).getType() == JavaType.Primitive.Null) {
return new J.Binary(Tree.randomId(), m.getPrefix(), Markers.EMPTY,
m.getSelect(),
JLeftPadded.build(J.Binary.Type.Equal).withBefore(Space.SINGLE_SPACE),
m.getArguments().get(0).withPrefix(Space.SINGLE_SPACE),
JavaType.Primitive.Boolean);
} else {
m = m.withSelect(((J.Literal) m.getArguments().get(0)).withPrefix(m.getSelect().getPrefix()))
.withArguments(singletonList(m.getSelect().withPrefix(Space.EMPTY)));
}
} else if (m.getArguments().get(0).getType() == JavaType.Primitive.Null) {
return new J.Binary(Tree.randomId(), m.getPrefix(), Markers.EMPTY,
requireNonNull(m.getSelect()),
JLeftPadded.build(J.Binary.Type.Equal).withBefore(Space.SINGLE_SPACE),
m.getArguments().get(0).withPrefix(Space.SINGLE_SPACE),
JavaType.Primitive.Boolean);
}

return m;
return m.withSelect(m.getArguments().get(0).withPrefix(requireNonNull(m.getSelect()).getPrefix()))
.withArguments(singletonList(m.getSelect().withPrefix(Space.EMPTY)));
}

private boolean isNullLiteral(Expression expression) {
Expand All @@ -88,14 +100,6 @@ private static class RemoveUnnecessaryNullCheck<P> extends JavaVisitor<P> {
private final J.Binary scope;
boolean done;

@Override
public @Nullable J visit(@Nullable Tree tree, P p) {
if (done) {
return (J) tree;
}
return super.visit(tree, p);
}

public RemoveUnnecessaryNullCheck(J.Binary scope) {
this.scope = scope;
}
Expand All @@ -106,7 +110,6 @@ public J visitBinary(J.Binary binary, P p) {
done = true;
return binary.getRight().withPrefix(Space.EMPTY);
}

return super.visitBinary(binary, p);
}
}
Expand Down
punkratz312 marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,9 @@ public class A {
String s = null;
if(s.equals("test")) {}
if(s.equalsIgnoreCase("test")) {}
System.out.println(s.compareTo("test"));
System.out.println(s.compareToIgnoreCase("test"));
System.out.println(s.contentEquals("test"));
}
}
""",
Expand All @@ -51,6 +54,9 @@ public class A {
String s = null;
if("test".equals(s)) {}
if("test".equalsIgnoreCase(s)) {}
System.out.println("test".compareTo(s));
System.out.println("test".compareToIgnoreCase(s));
System.out.println("test".contentEquals(s));
}
}
"""
Expand Down Expand Up @@ -88,17 +94,17 @@ public class A {
@Test
void nullLiteral() {
rewriteRun(
//language=java
java("""
//language=java
java("""
public class A {
void foo(String s) {
if(s.equals(null)) {
}
}
}
""",
"""

"""
public class A {
void foo(String s) {
if(s == null) {
Expand All @@ -108,4 +114,5 @@ void foo(String s) {
""")
);
}

}
Loading