Commit 965ea530 authored by Jinseong Jeon's avatar Jinseong Jeon
Browse files

FIR/UAST: support multi-resolution to compound assignments

^KTIJ-13815 Fixed
parent 7e9c8b99
Showing with 236 additions and 37 deletions
+236 -37
......@@ -3,16 +3,19 @@
package org.jetbrains.uast.kotlin
import com.intellij.psi.PsiMethod
import com.intellij.psi.ResolveResult
import org.jetbrains.annotations.ApiStatus
import org.jetbrains.kotlin.lexer.KtTokens
import org.jetbrains.kotlin.psi.KtBinaryExpression
import org.jetbrains.uast.*
import org.jetbrains.uast.kotlin.internal.getResolveResultVariants
@ApiStatus.Internal
class KotlinUBinaryExpression(
override val sourcePsi: KtBinaryExpression,
givenParent: UElement?
) : KotlinAbstractUExpression(givenParent), UBinaryExpression, KotlinUElementWithType, KotlinEvaluatableUElement {
) : KotlinAbstractUExpression(givenParent), UBinaryExpression, KotlinUElementWithType, KotlinEvaluatableUElement,
UMultiResolvable {
companion object {
val BITWISE_OPERATORS = mapOf(
......@@ -38,16 +41,16 @@ class KotlinUBinaryExpression(
}
override fun resolveOperator(): PsiMethod? {
baseResolveProviderService.resolveCall(sourcePsi)?.let { return it }
return when (sourcePsi.operationToken) {
KtTokens.EQ -> {
// array[index1, index2, ...] = v
(leftOperand as? UArrayAccessExpression)?.resolve() as? PsiMethod
}
else -> null
}
// array[index1, index2, ...] = v or ... += v
// NB: In the latter case, array getter is accessed first, hence the resolution points to that.
// To see if this binary operator can be resolved to array setter, use [UMultiResolvable#multiResolve] below.
((leftOperand as? UArrayAccessExpression)?.resolve() as? PsiMethod)?.let { return it }
return baseResolveProviderService.resolveCall(sourcePsi)
}
override fun multiResolve(): Iterable<ResolveResult> =
getResolveResultVariants(baseResolveProviderService, sourcePsi)
override val operator: UastBinaryOperator
get() = when (sourcePsi.operationToken) {
KtTokens.EQ -> UastBinaryOperator.ASSIGN
......
......@@ -3,18 +3,19 @@
package org.jetbrains.uast.kotlin
import com.intellij.psi.PsiMethod
import com.intellij.psi.ResolveResult
import org.jetbrains.annotations.ApiStatus
import org.jetbrains.kotlin.lexer.KtTokens
import org.jetbrains.kotlin.psi.KtPostfixExpression
import org.jetbrains.uast.*
import org.jetbrains.uast.kotlin.internal.DelegatedMultiResolve
import org.jetbrains.uast.kotlin.internal.getResolveResultVariants
@ApiStatus.Internal
class KotlinUPostfixExpression(
override val sourcePsi: KtPostfixExpression,
givenParent: UElement?
) : KotlinAbstractUExpression(givenParent), UPostfixExpression, KotlinUElementWithType, KotlinEvaluatableUElement,
UResolvable, DelegatedMultiResolve {
UResolvable, UMultiResolvable {
override val operand by lz {
baseResolveProviderService.baseKotlinConverter.convertOrEmpty(sourcePsi.baseExpression, this)
}
......@@ -36,4 +37,7 @@ class KotlinUPostfixExpression(
KtTokens.EXCLEXCL -> operand.tryResolve() as? PsiMethod
else -> null
}
override fun multiResolve(): Iterable<ResolveResult> =
getResolveResultVariants(baseResolveProviderService, sourcePsi)
}
......@@ -3,19 +3,19 @@
package org.jetbrains.uast.kotlin
import com.intellij.psi.PsiMethod
import com.intellij.psi.ResolveResult
import org.jetbrains.annotations.ApiStatus
import org.jetbrains.kotlin.lexer.KtTokens
import org.jetbrains.kotlin.psi.KtPrefixExpression
import org.jetbrains.uast.UElement
import org.jetbrains.uast.UIdentifier
import org.jetbrains.uast.UPrefixExpression
import org.jetbrains.uast.UastPrefixOperator
import org.jetbrains.uast.*
import org.jetbrains.uast.kotlin.internal.getResolveResultVariants
@ApiStatus.Internal
class KotlinUPrefixExpression(
override val sourcePsi: KtPrefixExpression,
givenParent: UElement?
) : KotlinAbstractUExpression(givenParent), UPrefixExpression, KotlinUElementWithType, KotlinEvaluatableUElement {
) : KotlinAbstractUExpression(givenParent), UPrefixExpression, KotlinUElementWithType, KotlinEvaluatableUElement,
UMultiResolvable {
override val operand by lz {
baseResolveProviderService.baseKotlinConverter.convertOrEmpty(sourcePsi.baseExpression, this)
}
......@@ -34,4 +34,7 @@ class KotlinUPrefixExpression(
KtTokens.MINUSMINUS -> UastPrefixOperator.DEC
else -> UastPrefixOperator.UNKNOWN
}
override fun multiResolve(): Iterable<ResolveResult> =
getResolveResultVariants(baseResolveProviderService, sourcePsi)
}
......@@ -860,8 +860,9 @@ interface UastResolveApiFixtureTestBase : UastPluginSelection {
myFixture.configureByText(
"main.kt", """
fun foo(array: SparseArray<String>) {
array[42L] = "forty two"
array[42L] = "forty"
val y = array[42]
array[42L] += " two"
}
""".trimIndent()
)
......@@ -891,6 +892,16 @@ interface UastResolveApiFixtureTestBase : UastPluginSelection {
TestCase.assertEquals(1, getResolved.parameterList.parameters.size)
TestCase.assertEquals("int", getResolved.parameterList.parameters[0].type.canonicalText)
TestCase.assertEquals("E", getResolved.returnType?.canonicalText)
val augmented = uFile.findElementByTextFromPsi<UBinaryExpression>("array[42L] +=", strict = false)
.orFail("cant convert to UBinaryExpression")
val augmentedResolved = augmented.resolveOperator()
.orFail("cant resolve from $augmented")
// NB: not exactly same as above one, which is `E get(int)`, whereas this one is `E get(long)`
TestCase.assertEquals(getResolved.name, augmentedResolved.name)
TestCase.assertEquals(1, augmentedResolved.parameterList.parameters.size)
TestCase.assertEquals("long", augmentedResolved.parameterList.parameters[0].type.canonicalText)
TestCase.assertEquals("E", augmentedResolved.returnType?.canonicalText)
}
fun checkOperatorOverloads(myFixture: JavaCodeInsightTestFixture) {
......@@ -980,6 +991,102 @@ interface UastResolveApiFixtureTestBase : UastPluginSelection {
TestCase.assertEquals("Point", plusPoint?.containingClass?.name)
}
fun checkOperatorMultiResolvable(myFixture: JavaCodeInsightTestFixture) {
myFixture.addClass(
"""
public class SparseArray<E> {
private Map<Long, E> map = new HashMap<Long, E>();
public void set(int key, E value) { map.put(key, value); }
public void set(long key, E value) { map.put(key, value); }
public E get(int key) { return map.get(key); }
public E get(long key) { return map.get(key); }
}
""".trimIndent()
)
myFixture.configureByText(
"main.kt", """
data class Point(val x: Int, val y: Int) {
operator fun inc() = Point(x + 1, y + 1)
}
operator fun Point.dec() = Point(x - 1, y - 1)
fun test(array: SparseArray<String>) {
var i = Point(0, 0)
i++
i--
++i
--i
array[42L] = "forty"
array[42L] += " two"
}
""".trimIndent()
)
val uFile = myFixture.file.toUElement()!!
val iPlusPlus = uFile.findElementByTextFromPsi<UPostfixExpression>("i++", strict = false)
.orFail("cant convert to UPostfixExpression")
val iPlusPlusResolvedDeclarations = (iPlusPlus as UMultiResolvable).multiResolve()
val iPlusPlusResolvedDeclarationsStrings = iPlusPlusResolvedDeclarations.map { it.element?.text ?: "<null>" }
assertContainsElements(
iPlusPlusResolvedDeclarationsStrings,
"var i = Point(0, 0)",
"operator fun inc() = Point(x + 1, y + 1)",
)
val iMinusMinus = uFile.findElementByTextFromPsi<UPostfixExpression>("i--", strict = false)
.orFail("cant convert to UPostfixExpression")
val iMinusMinusResolvedDeclarations = (iMinusMinus as UMultiResolvable).multiResolve()
val iMinusMinusResolvedDeclarationsStrings = iMinusMinusResolvedDeclarations.map { it.element?.text ?: "<null>" }
assertContainsElements(
iMinusMinusResolvedDeclarationsStrings,
"var i = Point(0, 0)",
"operator fun Point.dec() = Point(x - 1, y - 1)",
)
val plusPlusI = uFile.findElementByTextFromPsi<UPrefixExpression>("++i", strict = false)
.orFail("cant convert to UPrefixExpression")
val plusPlusIResolvedDeclarations = (plusPlusI as UMultiResolvable).multiResolve()
val plusPlusIResolvedDeclarationsStrings = plusPlusIResolvedDeclarations.map { it.element?.text ?: "<null>" }
assertContainsElements(
plusPlusIResolvedDeclarationsStrings,
"var i = Point(0, 0)",
"operator fun inc() = Point(x + 1, y + 1)",
)
val minusMinusI = uFile.findElementByTextFromPsi<UPrefixExpression>("--i", strict = false)
.orFail("cant convert to UPrefixExpression")
val minusMinusIResolvedDeclarations = (minusMinusI as UMultiResolvable).multiResolve()
val minusMinusIResolvedDeclarationsStrings = minusMinusIResolvedDeclarations.map { it.element?.text ?: "<null>" }
assertContainsElements(
minusMinusIResolvedDeclarationsStrings,
"var i = Point(0, 0)",
"operator fun Point.dec() = Point(x - 1, y - 1)",
)
val aEq = uFile.findElementByTextFromPsi<UBinaryExpression>("array[42L] =", strict = false)
.orFail("cant convert to UBinaryExpression")
val aEqResolvedDeclarations = (aEq as UMultiResolvable).multiResolve()
val aEqResolvedDeclarationsStrings = aEqResolvedDeclarations.map { it.element?.text ?: "<null>" }
assertContainsElements(
aEqResolvedDeclarationsStrings,
"public void set(long key, E value) { map.put(key, value); }",
)
val aPlusEq = uFile.findElementByTextFromPsi<UBinaryExpression>("array[42L] +=", strict = false)
.orFail("cant convert to UBinaryExpression")
val aPlusEqResolvedDeclarations = (aPlusEq as UMultiResolvable).multiResolve()
val aPlusEqResolvedDeclarationsStrings = aPlusEqResolvedDeclarations.map { it.element?.text ?: "<null>" }
assertContainsElements(
aPlusEqResolvedDeclarationsStrings,
"public E get(long key) { return map.get(key); }",
"public void set(long key, E value) { map.put(key, value); }",
)
}
fun checkResolveSyntheticJavaPropertyAccessor(myFixture: JavaCodeInsightTestFixture) {
myFixture.addClass(
"""public class X {
......
......@@ -188,15 +188,23 @@ interface FirKotlinUastResolveProviderService : BaseKotlinUastResolveProviderSer
override fun getReferenceVariants(ktExpression: KtExpression, nameHint: String): Sequence<PsiElement> {
analyzeForUast(ktExpression) {
return ktExpression.collectCallCandidates().asSequence().mapNotNull {
when (val candidate = it.candidate) {
is KtFunctionCall<*> -> {
toPsiMethod(candidate.partiallyAppliedSymbol.symbol, ktExpression)
}
is KtCompoundAccessCall -> {
toPsiMethod(candidate.compoundAccess.operationPartiallyAppliedSymbol.symbol, ktExpression)
return sequence {
ktExpression.collectCallCandidates().forEach { candidateInfo ->
when (val candidate = candidateInfo.candidate) {
is KtFunctionCall<*> -> {
toPsiMethod(candidate.partiallyAppliedSymbol.symbol, ktExpression)?.let { yield(it) }
}
is KtCompoundVariableAccessCall -> {
psiForUast(candidate.partiallyAppliedSymbol.symbol, ktExpression.project)?.let { yield(it) }
toPsiMethod(candidate.compoundAccess.operationPartiallyAppliedSymbol.symbol, ktExpression)?.let { yield(it) }
}
is KtCompoundArrayAccessCall -> {
toPsiMethod(candidate.getPartiallyAppliedSymbol.symbol, ktExpression)?.let { yield(it) }
toPsiMethod(candidate.setPartiallyAppliedSymbol.symbol, ktExpression)?.let { yield(it) }
toPsiMethod(candidate.compoundAccess.operationPartiallyAppliedSymbol.symbol, ktExpression)?.let { yield(it) }
}
else -> {}
}
else -> null
}
}
}
......
......@@ -168,8 +168,8 @@ class FirUastResolveApiFixtureTest : KotlinLightCodeInsightFixtureTestCase(), Ua
doCheck("OperatorOverloads", ::checkOperatorOverloads)
}
fun testResolveStaticImportFromObject() {
doCheck("ResolveStaticImportFromObject", ::checkResolveStaticImportFromObject)
fun testOperatorMultiResolvable() {
doCheck("OperatorMultiResolvable", ::checkOperatorMultiResolvable)
}
fun testResolveSyntheticJavaPropertyAccessor() {
......@@ -180,6 +180,10 @@ class FirUastResolveApiFixtureTest : KotlinLightCodeInsightFixtureTestCase(), Ua
doCheck("ResolveKotlinPropertyAccessor", ::checkResolveKotlinPropertyAccessor)
}
fun testResolveStaticImportFromObject() {
doCheck("ResolveStaticImportFromObject", ::checkResolveStaticImportFromObject)
}
fun testResolveToSubstituteOverride() {
doCheck("ResolveToSubstituteOverride", ::checkResolveToSubstituteOverride)
}
......
......@@ -53,9 +53,15 @@ class IdeaKotlinUastResolveProviderService : KotlinUastResolveProviderService {
override fun getReferenceVariants(ktExpression: KtExpression, nameHint: String): Sequence<PsiElement> {
val resolutionFacade = ktExpression.getResolutionFacade()
val bindingContext = ktExpression.safeAnalyzeNonSourceRootCode(resolutionFacade)
val call = ktExpression.getCall(bindingContext) ?: return emptySequence()
return call.resolveCandidates(bindingContext, resolutionFacade)
.mapNotNull { resolveToDeclarationImpl(ktExpression, it.candidateDescriptor) }
.asSequence()
return sequence {
// Use logic (shared with CLI) about handling compound assignments
yieldAll(super.getReferenceVariants(ktExpression, nameHint))
// Then, look for other candidates with a name hint
val call = ktExpression.getCall(bindingContext) ?: return@sequence
call.resolveCandidates(bindingContext, resolutionFacade)
.forEach {resolvedCall ->
resolveToDeclarationImpl(ktExpression, resolvedCall.candidateDescriptor)?.let { yield(it) }
}
}
}
}
......@@ -10,6 +10,7 @@ import org.jetbrains.kotlin.codegen.state.KotlinTypeMapper
import org.jetbrains.kotlin.config.LanguageVersionSettings
import org.jetbrains.kotlin.descriptors.*
import org.jetbrains.kotlin.idea.util.actionUnderSafeAnalyzeBlock
import org.jetbrains.kotlin.lexer.KtTokens
import org.jetbrains.kotlin.name.FqNameUnsafe
import org.jetbrains.kotlin.psi.*
import org.jetbrains.kotlin.psi.psiUtil.getParentOfType
......@@ -174,6 +175,57 @@ interface KotlinUastResolveProviderService : BaseKotlinUastResolveProviderServic
return psiElement.actionUnderSafeAnalyzeBlock({ psiElement.annotations }, { emptyArray() })
}
override fun getReferenceVariants(ktExpression: KtExpression, nameHint: String): Sequence<PsiElement> {
val unwrappedPsi = KtPsiUtil.deparenthesize(ktExpression) ?: ktExpression
return sequence {
when (unwrappedPsi) {
is KtUnaryExpression -> {
if (unwrappedPsi.operationToken in KtTokens.INCREMENT_AND_DECREMENT ||
unwrappedPsi.operationToken == KtTokens.EXCLEXCL
) {
// E.g., `i++` -> access to `i`
unwrappedPsi.baseExpression?.let { resolveToDeclaration(it) }?.let { yield(it) }
}
// Look for regular function call, e.g., inc() in `i++`
resolveToDeclaration(ktExpression)?.let { yield(it) }
}
is KtBinaryExpression -> {
val left = unwrappedPsi.left
when (unwrappedPsi.operationToken) {
KtTokens.EQ -> {
if (left is KtArrayAccessExpression) {
// E.g., `array[...] = ...` -> access to `array[...]`, i.e., (overloaded) setter
val context = left.analyze()
val resolvedSetCall = context[BindingContext.INDEXED_LVALUE_SET, left]
resolvedSetCall?.resultingDescriptor?.let { resolveToPsiMethod(unwrappedPsi, it) }?.let { yield(it) }
} else {
// E.g. `i = ...` -> access to `i`
left?.let { resolveToDeclaration(it) }?.let { yield(it) }
}
}
in KtTokens.AUGMENTED_ASSIGNMENTS -> {
if (left is KtArrayAccessExpression) {
// E.g., `array[...] += ...` -> access to `array[...]`, i.e., (overloaded) getter and setter
val context = left.analyze()
val resolvedGetCall = context[BindingContext.INDEXED_LVALUE_GET, left]
resolvedGetCall?.resultingDescriptor?.let { resolveToPsiMethod(unwrappedPsi, it) }?.let { yield(it) }
val resolvedSetCall = context[BindingContext.INDEXED_LVALUE_SET, left]
resolvedSetCall?.resultingDescriptor?.let { resolveToPsiMethod(unwrappedPsi, it) }?.let { yield(it) }
} else {
// Look for regular function call, e.g., plusAssign() in `i += j`
resolveToDeclaration(ktExpression)?.let { yield(it) }
}
}
else -> {}
}
}
else -> {
// TODO: regular function call resolution?
}
}
}
}
override fun resolveBitwiseOperators(ktBinaryExpression: KtBinaryExpression): UastBinaryOperator {
val other = UastBinaryOperator.OTHER
val ref = ktBinaryExpression.operationReference
......
......@@ -14,7 +14,6 @@ import org.jetbrains.kotlin.context.ProjectContext
import org.jetbrains.kotlin.descriptors.ModuleDescriptor
import org.jetbrains.kotlin.metadata.jvm.deserialization.JvmProtoBufUtil
import org.jetbrains.kotlin.psi.KtElement
import org.jetbrains.kotlin.psi.KtExpression
import org.jetbrains.kotlin.psi.KtFile
import org.jetbrains.kotlin.resolve.BindingContext
import org.jetbrains.kotlin.resolve.BindingTrace
......@@ -47,9 +46,6 @@ class CliKotlinUastResolveProviderService : KotlinUastResolveProviderService {
override fun getLanguageVersionSettings(element: KtElement): LanguageVersionSettings {
return element.project.analysisCompletedHandler?.getLanguageVersionSettings() ?: LanguageVersionSettingsImpl.DEFAULT
}
override fun getReferenceVariants(ktExpression: KtExpression, nameHint: String): Sequence<PsiElement> =
emptySequence() // Not supported
}
class UastAnalysisHandlerExtension : AnalysisHandlerExtension {
......
......@@ -282,8 +282,20 @@ internal fun KotlinULambdaExpression.getFunctionalInterfaceType(): PsiType? {
return sourcePsi.getExpectedType()?.getFunctionalInterfaceType(this, sourcePsi, sourcePsi.typeOwnerKind)
}
internal fun resolveToPsiMethod(context: KtElement): PsiMethod? =
context.getResolvedCall(context.analyze())?.resultingDescriptor?.let { resolveToPsiMethod(context, it) }
internal fun resolveToPsiMethod(ktElement: KtElement): PsiMethod? {
val context = ktElement.analyze()
if (ktElement is KtArrayAccessExpression) {
// Try getter first, e.g., array[...] += v, ... = array[...], or ...(..., array[...], ...)
context[BindingContext.INDEXED_LVALUE_GET, ktElement]?.resultingDescriptor
?.let { resolveToPsiMethod(ktElement, it) }
?.let { return it }
// Then, setter, e.g., array[...] = v
context[BindingContext.INDEXED_LVALUE_SET, ktElement]?.resultingDescriptor
?.let { resolveToPsiMethod(ktElement, it) }
?.let { return it }
}
return ktElement.getResolvedCall(context)?.resultingDescriptor?.let { resolveToPsiMethod(ktElement, it) }
}
internal fun resolveToPsiMethod(
context: KtElement,
......
......@@ -123,6 +123,10 @@ class FE1UastResolveApiFixtureTest : KotlinLightCodeInsightFixtureTestCase(), Ua
checkOperatorOverloads(myFixture)
}
fun testOperatorMultiResolvable() {
checkOperatorMultiResolvable(myFixture)
}
fun testResolveSyntheticJavaPropertyAccessor() {
checkResolveSyntheticJavaPropertyAccessor(myFixture)
}
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment