Skip to content

fix RemoveUnusedImportsStep leftovers #1337

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,16 @@

package com.palantir.javaformat.java;

import static java.lang.Math.max;
import static java.nio.charset.StandardCharsets.UTF_8;

import com.google.common.base.CharMatcher;
import com.google.common.collect.HashMultimap;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Multimap;
import com.google.common.collect.Range;
import com.google.common.collect.RangeMap;
import com.google.common.collect.RangeSet;
import com.google.common.collect.TreeRangeMap;
import com.google.common.collect.TreeRangeSet;
import com.palantir.javaformat.Newlines;
import com.sun.source.doctree.DocCommentTree;
import com.sun.source.doctree.ReferenceTree;
Expand All @@ -36,20 +38,30 @@
import com.sun.source.util.TreePathScanner;
import com.sun.source.util.TreeScanner;
import com.sun.tools.javac.api.JavacTrees;
import com.sun.tools.javac.file.JavacFileManager;
import com.sun.tools.javac.parser.ParserFactory;
import com.sun.tools.javac.tree.DCTree;
import com.sun.tools.javac.tree.DCTree.DCReference;
import com.sun.tools.javac.tree.JCTree;
import com.sun.tools.javac.tree.JCTree.JCCompilationUnit;
import com.sun.tools.javac.tree.JCTree.JCFieldAccess;
import com.sun.tools.javac.tree.JCTree.JCIdent;
import com.sun.tools.javac.tree.JCTree.JCImport;
import com.sun.tools.javac.util.Context;
import com.sun.tools.javac.util.Log;
import com.sun.tools.javac.util.Options;
import java.io.IOException;
import java.lang.reflect.Method;
import java.net.URI;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import javax.annotation.Nullable;
import javax.tools.DiagnosticCollector;
import javax.tools.DiagnosticListener;
import javax.tools.JavaFileObject;
import javax.tools.SimpleJavaFileObject;
import javax.tools.StandardLocation;

/**
* Removes unused imports from a source file. Imports that are only used in javadoc are also removed, and the references
Expand All @@ -76,15 +88,12 @@ public class RemoveUnusedImports {
private static final class UnusedImportScanner extends TreePathScanner<Void, Void> {

private final Set<String> usedNames = new LinkedHashSet<>();

private final Multimap<String, Range<Integer>> usedInJavadoc = HashMultimap.create();

final JavacTrees trees;
final DocTreeScanner docTreeSymbolScanner;
private final DocTreeScanner docTreeSymbolScanner = new DocTreeScanner();
private final JavacTrees trees;

private UnusedImportScanner(JavacTrees trees) {
this.trees = trees;
docTreeSymbolScanner = new DocTreeScanner();
}

/** Skip the imports themselves when checking for usage. */
Expand Down Expand Up @@ -202,21 +211,50 @@ public Void visitIdentifier(IdentifierTree node, Void aVoid) {
}
}

public static String removeUnusedImports(final String contents) throws FormatterException {
public static String removeUnusedImports(final String contents) {
Context context = new Context();
JCCompilationUnit unit = parse(context, contents);
if (unit == null) {
// error handling is done during formatting
return contents;
}
UnusedImportScanner scanner = new UnusedImportScanner(JavacTrees.instance(context));
scanner.scan(unit, null);
return applyReplacements(contents, buildReplacements(contents, unit, scanner.usedNames, scanner.usedInJavadoc));

String s = contents;

// Normalize newlines while preserving important blank lines
String sep = Newlines.guessLineSeparator(contents);

// Ensure exactly one blank line after package declaration
s = s.replaceAll("(?m)^(package .+)" + sep + "\\s+" + sep, "$1" + sep + sep);

// Ensure exactly one blank line between last import and class declaration
s = s.replaceAll("(?m)^(import .+)" + sep + "\\s+" + sep + "(?=class|interface|enum|record)", "$1" + sep + sep);

// Remove multiple blank lines elsewhere in imports section
s = s.replaceAll("(?m)^(import .+)" + sep + "\\s+" + sep + "(?=import)", "$1" + sep);

// Apply replacements last, after formatting
return applyReplacements(s, buildReplacements(s, unit, scanner.usedNames, scanner.usedInJavadoc));
}

private static JCCompilationUnit parse(Context context, String javaInput) throws FormatterException {
private static JCCompilationUnit parse(Context context, String javaInput) {
context.put(DiagnosticListener.class, new DiagnosticCollector<JavaFileObject>());
Options.instance(context).put("allowStringFolding", "false");
return Formatter.parseJcCompilationUnit(context, javaInput);
try (JavacFileManager fileManager = new JavacFileManager(context, true, UTF_8)) {
fileManager.setLocation(StandardLocation.PLATFORM_CLASS_PATH, ImmutableList.of());
} catch (IOException e) {
throw new RuntimeException(e);
}
SimpleJavaFileObject source = new SimpleJavaFileObject(URI.create("source"), JavaFileObject.Kind.SOURCE) {
@Override
public CharSequence getCharContent(boolean ignoreEncodingErrors) {
return javaInput;
}
};
Log.instance(context).useSource(source);
JCCompilationUnit unit = ParserFactory.instance(context)
.newParser(javaInput, true, true, true)
.parseCompilationUnit();
unit.sourcefile = source;
return unit;
}

/** Construct replacements to fix unused imports. */
Expand All @@ -226,70 +264,93 @@ private static RangeMap<Integer, String> buildReplacements(
Set<String> usedNames,
Multimap<String, Range<Integer>> usedInJavadoc) {
RangeMap<Integer, String> replacements = TreeRangeMap.create();
for (JCImport importTree : unit.getImports()) {
String simpleName = getSimpleName(importTree);
if (!isUnused(unit, usedNames, usedInJavadoc, importTree, simpleName)) {
continue;
}
// delete the import
int endPosition = importTree.getEndPosition(unit.endPositions);
endPosition = Math.max(CharMatcher.isNot(' ').indexIn(contents, endPosition), endPosition);
String sep = Newlines.guessLineSeparator(contents);
if (endPosition + sep.length() < contents.length()
&& contents.subSequence(endPosition, endPosition + sep.length())
.toString()
.equals(sep)) {
int size = unit.getImports().size();
unit.getImports().stream()
.filter(importTree -> isUnused(
unit,
usedNames,
usedInJavadoc,
importTree,
getQualifiedIdentifier(importTree).getIdentifier().toString()))
.forEach(importTree -> replacements.put(
Range.closedOpen(
importTree.getStartPosition(),
calculateEndPosition(
contents,
importTree,
unit,
Newlines.guessLineSeparator(contents),
size,
size > 0 ? unit.getImports().get(size - 1) : null)),
""));

return replacements;
}

private static int calculateEndPosition(
String contents,
JCTree importTree,
JCCompilationUnit unit,
String sep,
int size,
@Nullable JCTree lastImport) {
int endPosition = importTree.getEndPosition(unit.endPositions);
endPosition = max(CharMatcher.isNot(' ').indexIn(contents, endPosition), endPosition);
if (endPosition + sep.length() < contents.length()
&& contents.subSequence(endPosition, endPosition + sep.length())
.toString()
.equals(sep)) {
endPosition += sep.length();
} else if ((size == 1 || importTree != lastImport) && !checkForEmptyLineAfter(contents, endPosition, sep)) {
while (endPosition + sep.length() <= contents.length()
&& contents.regionMatches(endPosition, sep, 0, sep.length())) {
endPosition += sep.length();
}
replacements.put(Range.closedOpen(importTree.getStartPosition(), endPosition), "");
}
return replacements;
return endPosition;
}

private static String getSimpleName(ImportTree importTree) {
return importTree.getQualifiedIdentifier() instanceof JCIdent
? ((JCIdent) importTree.getQualifiedIdentifier()).getName().toString()
: ((JCFieldAccess) importTree.getQualifiedIdentifier())
.getIdentifier()
.toString();
private static boolean checkForEmptyLineAfter(String contents, int endPosition, String sep) {
return endPosition + sep.length() * 2 <= contents.length()
&& contents.substring(endPosition, endPosition + sep.length() * 2)
.equals(sep + sep);
}

private static boolean isUnused(
JCCompilationUnit unit,
Set<String> usedNames,
Multimap<String, Range<Integer>> usedInJavadoc,
ImportTree importTree,
JCTree importTree,
String simpleName) {
String qualifier = ((JCFieldAccess) importTree.getQualifiedIdentifier())
.getExpression()
.toString();
JCFieldAccess qualifiedIdentifier = getQualifiedIdentifier(importTree);
String qualifier = qualifiedIdentifier.getExpression().toString();
if (qualifier.equals("java.lang")) {
return true;
}
if (usedNames.contains(simpleName)) {
return false;
}
if (unit.getPackageName() != null && unit.getPackageName().toString().equals(qualifier)) {
return true;
}
if (importTree.getQualifiedIdentifier() instanceof JCFieldAccess
&& ((JCFieldAccess) importTree.getQualifiedIdentifier())
.getIdentifier()
.contentEquals("*")) {
if (qualifiedIdentifier.getIdentifier().contentEquals("*") && !((JCImport) importTree).isStatic()) {
return false;
}
return !usedInJavadoc.containsKey(simpleName);
}

if (usedNames.contains(simpleName)) {
return false;
private static JCFieldAccess getQualifiedIdentifier(JCTree importTree) {
// Use reflection because the return type is JCTree in some versions and JCFieldAccess in others
try {
return (JCFieldAccess)
JCImport.class.getMethod("getQualifiedIdentifier").invoke(importTree);
} catch (ReflectiveOperationException e) {
throw new RuntimeException(e);
}
if (usedInJavadoc.containsKey(simpleName)) {
return false;
}
return true;
}

/** Applies the replacements to the given source, and re-format any edited javadoc. */
private static String applyReplacements(String source, RangeMap<Integer, String> replacements) {
// save non-empty fixed ranges for reformatting after fixes are applied
RangeSet<Integer> fixedRanges = TreeRangeSet.create();

// Apply the fixes in increasing order, adjusting ranges to account for
// earlier fixes that change the length of the source. The output ranges are
// needed so we can reformat fixed regions, otherwise the fixes could just
Expand All @@ -299,14 +360,11 @@ private static String applyReplacements(String source, RangeMap<Integer, String>
for (Map.Entry<Range<Integer>, String> replacement :
replacements.asMapOfRanges().entrySet()) {
Range<Integer> range = replacement.getKey();
String replaceWith = replacement.getValue();
int start = offset + range.lowerEndpoint();
int end = offset + range.upperEndpoint();
sb.replace(start, end, replaceWith);
if (!replaceWith.isEmpty()) {
fixedRanges.add(Range.closedOpen(start, end));
if (replacement.getValue().isBlank()) {
String replaceWith = replacement.getValue();
sb.replace(offset + range.lowerEndpoint(), offset + range.upperEndpoint(), replaceWith);
offset += replaceWith.length() - (range.upperEndpoint() - range.lowerEndpoint());
}
offset += replaceWith.length() - (range.upperEndpoint() - range.lowerEndpoint());
}
return sb.toString();
}
Expand Down
Loading