Commit 568b1161 authored by Mikhail Golubev's avatar Mikhail Golubev
Browse files

Fix PIEAE in PyExtractMethodTest

They were caused by the fact that we first replace selected fragment
with method call and then use invalided elements of the same fragment
to find its duplicates.
I split ExtractMethodHelper#processDuplicates() into two methods:
collectDuplicates() that finds duplicates before the substitution is
performed and replaceDuplicatesWithPrompt() that handles user
notification and replacing found occurrences afterwards.
The same way this refactoring is implemented for Java sources
(see ExtractMethodHandler.invokeOnElements()).
parent 6ccd1fd3
Showing with 72 additions and 19 deletions
+72 -19
...@@ -19,6 +19,7 @@ import com.intellij.codeInsight.highlighting.HighlightManager; ...@@ -19,6 +19,7 @@ import com.intellij.codeInsight.highlighting.HighlightManager;
import com.intellij.find.FindManager; import com.intellij.find.FindManager;
import com.intellij.openapi.application.ApplicationManager; import com.intellij.openapi.application.ApplicationManager;
import com.intellij.openapi.application.ApplicationNamesInfo; import com.intellij.openapi.application.ApplicationNamesInfo;
import com.intellij.openapi.application.ReadAction;
import com.intellij.openapi.command.CommandProcessor; import com.intellij.openapi.command.CommandProcessor;
import com.intellij.openapi.editor.Editor; import com.intellij.openapi.editor.Editor;
import com.intellij.openapi.editor.LogicalPosition; import com.intellij.openapi.editor.LogicalPosition;
...@@ -76,6 +77,45 @@ public class ExtractMethodHelper { ...@@ -76,6 +77,45 @@ public class ExtractMethodHelper {
}); });
} }
/**
* Finds duplicates of the code fragment specified in the finder in given scopes.
*
* @param finder finder object to seek for duplicates
* @param searchScopes scopes where to look them in
* @param generatedMethod new method that should be excluded from the search
* @return list of duplicate code fragments discovered
* @see #replaceDuplicatesWithPrompt(List, PsiElement, Editor, Consumer)
*/
@NotNull
public static List<SimpleMatch> collectDuplicates(@NotNull SimpleDuplicatesFinder finder,
@NotNull List<PsiElement> searchScopes,
@NotNull PsiElement generatedMethod) {
if (ApplicationManager.getApplication().isUnitTestMode()) {
return finder.findDuplicates(searchScopes, generatedMethod);
}
return ReadAction.compute(() -> finder.findDuplicates(searchScopes, generatedMethod));
}
/**
* Notifies user about found duplicates and then highlights each of them in the editor and asks user how to proceed.
*
* @param duplicates discovered duplicates of extracted code fragment
* @param replacement generated expression or statement that contains invocation of the new method
* @param editor instance of editor where refactoring is performed
* @param replacer strategy of substituting each duplicate occurence with the replacement fragment
* @see #collectDuplicates(SimpleDuplicatesFinder, List, PsiElement)
*/
public static void replaceDuplicatesWithPrompt(@NotNull List<SimpleMatch> duplicates,
@NotNull PsiElement replacement,
@NotNull Editor editor,
@NotNull Consumer<Pair<SimpleMatch, PsiElement>> replacer) {
if (ApplicationManager.getApplication().isUnitTestMode()) {
replaceDuplicates(replacement, editor, replacer, duplicates);
}
ApplicationManager.getApplication().invokeLater(() -> replaceDuplicates(replacement, editor, replacer, duplicates));
}
private static void replaceDuplicates(PsiElement callElement, private static void replaceDuplicates(PsiElement callElement,
Editor editor, Editor editor,
Consumer<Pair<SimpleMatch, PsiElement>> replacer, Consumer<Pair<SimpleMatch, PsiElement>> replacer,
......
...@@ -22,10 +22,11 @@ import java.util.Set; ...@@ -22,10 +22,11 @@ import java.util.Set;
* User : ktisha * User : ktisha
*/ */
public class SimpleDuplicatesFinder { public class SimpleDuplicatesFinder {
private static final Key<PsiElement> PARAMETER = Key.create("PARAMETER");
protected PsiElement myReplacement; protected PsiElement myReplacement;
private final ArrayList<PsiElement> myPattern; private final ArrayList<PsiElement> myPattern;
private final Set<String> myParameters; private final Set<String> myParameters;
public static final Key<PsiElement> PARAMETER = Key.create("PARAMETER");
private final Collection<String> myOutputVariables; private final Collection<String> myOutputVariables;
@Deprecated @Deprecated
......
...@@ -40,7 +40,6 @@ import com.intellij.refactoring.rename.RenameUtil; ...@@ -40,7 +40,6 @@ import com.intellij.refactoring.rename.RenameUtil;
import com.intellij.refactoring.util.AbstractVariableData; import com.intellij.refactoring.util.AbstractVariableData;
import com.intellij.refactoring.util.CommonRefactoringUtil; import com.intellij.refactoring.util.CommonRefactoringUtil;
import com.intellij.usageView.UsageInfo; import com.intellij.usageView.UsageInfo;
import com.intellij.util.Consumer;
import com.intellij.util.Function; import com.intellij.util.Function;
import com.intellij.util.IncorrectOperationException; import com.intellij.util.IncorrectOperationException;
import com.intellij.util.containers.ContainerUtil; import com.intellij.util.containers.ContainerUtil;
...@@ -168,14 +167,17 @@ public class PyExtractMethodUtil { ...@@ -168,14 +167,17 @@ public class PyExtractMethodUtil {
final PyFunction function1 = generator.createFromText(languageLevel, PyFunction.class, builder.toString()); final PyFunction function1 = generator.createFromText(languageLevel, PyFunction.class, builder.toString());
PsiElement callElement = function1.getStatementList().getStatements()[0]; PsiElement callElement = function1.getStatementList().getStatements()[0];
// replace statements with call
callElement = replaceElements(elementsRange, callElement);
// Both statements are used in finder, so should be valid at this moment // Both statements are used in finder, so should be valid at this moment
PyPsiUtils.assertValid(statement1); PyPsiUtils.assertValid(statement1);
PyPsiUtils.assertValid(statement2); PyPsiUtils.assertValid(statement2);
final List<SimpleMatch> duplicates = collectDuplicates(finder, statement1, generatedMethod);
// replace statements with call
callElement = replaceElements(elementsRange, callElement);
callElement = CodeInsightUtilCore.forcePsiPostprocessAndRestoreElement(callElement); callElement = CodeInsightUtilCore.forcePsiPostprocessAndRestoreElement(callElement);
if (callElement != null) { if (callElement != null) {
processDuplicates(callElement, generatedMethod, finder, editor); processDuplicates(duplicates, callElement, editor);
} }
// Set editor // Set editor
...@@ -188,13 +190,19 @@ public class PyExtractMethodUtil { ...@@ -188,13 +190,19 @@ public class PyExtractMethodUtil {
}), PyBundle.message("refactoring.extract.method"), null); }), PyBundle.message("refactoring.extract.method"), null);
} }
private static void processDuplicates(@NotNull final PsiElement callElement, @NotNull
@NotNull final PyFunction generatedMethod, private static List<SimpleMatch> collectDuplicates(@NotNull SimpleDuplicatesFinder finder,
@NotNull final SimpleDuplicatesFinder finder, @NotNull PsiElement originalScopeAnchor,
@NotNull final Editor editor) { @NotNull PyFunction generatedMethod) {
final ScopeOwner owner = ScopeUtil.getScopeOwner(callElement); final List<PsiElement> scopes = collectScopes(originalScopeAnchor, generatedMethod);
if (owner instanceof PsiFile) return; return ExtractMethodHelper.collectDuplicates(finder, scopes, generatedMethod);
final List<PsiElement> scope = new ArrayList<PsiElement>(); }
@NotNull
private static List<PsiElement> collectScopes(@NotNull PsiElement anchor, @NotNull PyFunction generatedMethod) {
final ScopeOwner owner = ScopeUtil.getScopeOwner(anchor);
if (owner instanceof PsiFile) return Collections.emptyList();
final List<PsiElement> scope = new ArrayList<>();
if (owner instanceof PyFunction) { if (owner instanceof PyFunction) {
scope.add(owner); scope.add(owner);
final PyClass containingClass = ((PyFunction)owner).getContainingClass(); final PyClass containingClass = ((PyFunction)owner).getContainingClass();
...@@ -206,9 +214,13 @@ public class PyExtractMethodUtil { ...@@ -206,9 +214,13 @@ public class PyExtractMethodUtil {
} }
} }
} }
ExtractMethodHelper.processDuplicates(callElement, generatedMethod, scope, finder, editor, return scope;
pair -> replaceElements(pair.first, pair.second.copy()) }
);
private static void processDuplicates(@NotNull List<SimpleMatch> duplicates,
@NotNull PsiElement replacement,
@NotNull Editor editor) {
ExtractMethodHelper.replaceDuplicatesWithPrompt(duplicates, replacement, editor, pair -> replaceElements(pair.first, pair.second.copy()));
} }
private static void processGlobalWrites(@NotNull final PyFunction function, @NotNull final PyCodeFragment fragment) { private static void processGlobalWrites(@NotNull final PyFunction function, @NotNull final PyCodeFragment fragment) {
...@@ -337,12 +349,14 @@ public class PyExtractMethodUtil { ...@@ -337,12 +349,14 @@ public class PyExtractMethodUtil {
callElement = ((PyExpressionStatement)generated).getExpression(); callElement = ((PyExpressionStatement)generated).getExpression();
} }
PyPsiUtils.assertValid(expression);
final List<SimpleMatch> duplicates = collectDuplicates(finder, expression, generatedMethod);
// replace statements with call // replace statements with call
if (callElement != null) { if (callElement != null) {
callElement = PyReplaceExpressionUtil.replaceExpression(expression, callElement); callElement = PyReplaceExpressionUtil.replaceExpression(expression, callElement);
} }
if (callElement != null) { if (callElement != null) {
processDuplicates(callElement, generatedMethod, finder, editor); processDuplicates(duplicates, callElement, editor);
} }
// Set editor // Set editor
setSelectionAndCaret(editor, callElement); setSelectionAndCaret(editor, callElement);
...@@ -394,9 +408,7 @@ public class PyExtractMethodUtil { ...@@ -394,9 +408,7 @@ public class PyExtractMethodUtil {
for (PyExpression arg : argumentList.getArguments()) { for (PyExpression arg : argumentList.getArguments()) {
final String argText = arg.getText(); final String argText = arg.getText();
if (argText != null && keys.contains(argText)) { if (argText != null && keys.contains(argText)) {
arg.replace(generator.createExpressionFromText( arg.replace(generator.createExpressionFromText(LanguageLevel.forElement(callElement), changedParameters.get(argText)));
LanguageLevel.forElement(callElement),
changedParameters.get(argText)));
} }
} }
} }
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment