diff --git a/CHANGELOG.md b/CHANGELOG.md index 1434d5a09..b6b6efdcd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -20,6 +20,8 @@ ### Custom Code Generator +- Add ability to query top-level functions. This allows you write code generators for top-level functions, see #644. + ## [2.4.3] - 2022-12-16 ### Added diff --git a/compiler-utils/src/main/java/com/squareup/anvil/compiler/internal/reference/AnvilModuleDescriptor.kt b/compiler-utils/src/main/java/com/squareup/anvil/compiler/internal/reference/AnvilModuleDescriptor.kt index f263c9e65..1309ee0fb 100644 --- a/compiler-utils/src/main/java/com/squareup/anvil/compiler/internal/reference/AnvilModuleDescriptor.kt +++ b/compiler-utils/src/main/java/com/squareup/anvil/compiler/internal/reference/AnvilModuleDescriptor.kt @@ -29,6 +29,8 @@ public interface AnvilModuleDescriptor : ModuleDescriptor { public fun getClassAndInnerClassReferences(ktFile: KtFile): List + public fun getTopLevelFunctionReferences(ktFile: KtFile): List + public fun getClassReference(clazz: KtClassOrObject): Psi public fun getClassReference(descriptor: ClassDescriptor): Descriptor @@ -58,3 +60,12 @@ public fun Collection.classAndInnerClassReferences( module.asAnvilModuleDescriptor().getClassAndInnerClassReferences(it) } } + +@ExperimentalAnvilApi +public fun Collection.topLevelFunctionReferences( + module: ModuleDescriptor +): Sequence { + return asSequence().flatMap { + module.asAnvilModuleDescriptor().getTopLevelFunctionReferences(it) + } +} diff --git a/compiler-utils/src/main/java/com/squareup/anvil/compiler/internal/reference/FunctionReference.kt b/compiler-utils/src/main/java/com/squareup/anvil/compiler/internal/reference/FunctionReference.kt index 09834d23c..0d762b907 100644 --- a/compiler-utils/src/main/java/com/squareup/anvil/compiler/internal/reference/FunctionReference.kt +++ b/compiler-utils/src/main/java/com/squareup/anvil/compiler/internal/reference/FunctionReference.kt @@ -1,7 +1,6 @@ package com.squareup.anvil.compiler.internal.reference import com.squareup.anvil.annotations.ExperimentalAnvilApi -import com.squareup.anvil.compiler.api.AnvilCompilationException import com.squareup.anvil.compiler.internal.reference.FunctionReference.Descriptor import com.squareup.anvil.compiler.internal.reference.FunctionReference.Psi import com.squareup.anvil.compiler.internal.reference.Visibility.INTERNAL @@ -30,29 +29,18 @@ import kotlin.LazyThreadSafetyMode.NONE * [FunctionDescriptor] references, to streamline parsing. */ @ExperimentalAnvilApi -public sealed class FunctionReference : AnnotatedReference { +public sealed class FunctionReference : AnnotatedReference, FunctionalReference { - public abstract val fqName: FqName public abstract val declaringClass: ClassReference - public val name: String get() = fqName.shortName().asString() - public val module: AnvilModuleDescriptor get() = declaringClass.module - - public abstract val parameters: List + public override val module: AnvilModuleDescriptor get() = declaringClass.module protected abstract val returnType: TypeReference? - public fun returnTypeOrNull(): TypeReference? = returnType - public fun returnType(): TypeReference = returnType - ?: throw AnvilCompilationExceptionFunctionReference( - functionReference = this, - message = "Unable to get the return type for function $fqName." - ) + public override fun returnTypeOrNull(): TypeReference? = returnType public abstract fun isAbstract(): Boolean public abstract fun isConstructor(): Boolean - public abstract fun visibility(): Visibility - public fun resolveGenericReturnTypeOrNull( implementingClass: ClassReference ): ClassReference? { @@ -63,8 +51,8 @@ public sealed class FunctionReference : AnnotatedReference { public fun resolveGenericReturnType(implementingClass: ClassReference): ClassReference = resolveGenericReturnTypeOrNull(implementingClass) - ?: throw AnvilCompilationExceptionFunctionReference( - functionReference = this, + ?: throw AnvilCompilationExceptionFunctionalReference( + functionalReference = this, message = "Unable to resolve return type for function $fqName with the implementing " + "class ${implementingClass.fqName}." ) @@ -73,7 +61,7 @@ public sealed class FunctionReference : AnnotatedReference { override fun equals(other: Any?): Boolean { if (this === other) return true - if (other !is ClassReference) return false + if (other !is FunctionReference) return false if (fqName != other.fqName) return false @@ -85,10 +73,10 @@ public sealed class FunctionReference : AnnotatedReference { } public class Psi internal constructor( - public val function: KtFunction, + public override val function: KtFunction, override val declaringClass: ClassReference.Psi, override val fqName: FqName - ) : FunctionReference() { + ) : FunctionReference(), FunctionalReference.Psi { override val annotations: List by lazy(NONE) { function.annotationEntries.map { @@ -125,10 +113,10 @@ public sealed class FunctionReference : AnnotatedReference { } public class Descriptor internal constructor( - public val function: FunctionDescriptor, + public override val function: FunctionDescriptor, override val declaringClass: ClassReference.Descriptor, override val fqName: FqName = function.fqNameSafe - ) : FunctionReference() { + ) : FunctionReference(), FunctionalReference.Descriptor { override val annotations: List by lazy(NONE) { function.annotations.map { @@ -144,16 +132,6 @@ public sealed class FunctionReference : AnnotatedReference { function.returnType?.toTypeReference(declaringClass, module) } - internal val overriddenFunctions by lazy(NONE) { - generateSequence( - function.overriddenDescriptors - ) { overriddenFunctions -> - overriddenFunctions - .flatMap { it.overriddenDescriptors } - .takeIf { it.isNotEmpty() } - }.flatten().toList() - } - override fun isAbstract(): Boolean = function.modality == ABSTRACT override fun isConstructor(): Boolean = function is ClassConstructorDescriptor @@ -192,22 +170,3 @@ public fun FunctionDescriptor.toFunctionReference( ): Descriptor { return Descriptor(this, declaringClass) } - -@ExperimentalAnvilApi -@Suppress("FunctionName") -public fun AnvilCompilationExceptionFunctionReference( - functionReference: FunctionReference, - message: String, - cause: Throwable? = null -): AnvilCompilationException = when (functionReference) { - is Psi -> AnvilCompilationException( - element = functionReference.function, - message = message, - cause = cause - ) - is Descriptor -> AnvilCompilationException( - functionDescriptor = functionReference.function, - message = message, - cause = cause - ) -} diff --git a/compiler-utils/src/main/java/com/squareup/anvil/compiler/internal/reference/FunctionalReference.kt b/compiler-utils/src/main/java/com/squareup/anvil/compiler/internal/reference/FunctionalReference.kt new file mode 100644 index 000000000..ce0a028f5 --- /dev/null +++ b/compiler-utils/src/main/java/com/squareup/anvil/compiler/internal/reference/FunctionalReference.kt @@ -0,0 +1,55 @@ +package com.squareup.anvil.compiler.internal.reference + +import com.squareup.anvil.annotations.ExperimentalAnvilApi +import com.squareup.anvil.compiler.api.AnvilCompilationException +import com.squareup.anvil.compiler.internal.reference.FunctionalReference.Descriptor +import com.squareup.anvil.compiler.internal.reference.FunctionalReference.Psi +import org.jetbrains.kotlin.descriptors.FunctionDescriptor +import org.jetbrains.kotlin.name.FqName +import org.jetbrains.kotlin.psi.KtFunction + +@ExperimentalAnvilApi +public sealed interface FunctionalReference { + public val fqName: FqName + public val name: String get() = fqName.shortName().asString() + + public val module: AnvilModuleDescriptor + + public val parameters: List + + public fun returnTypeOrNull(): TypeReference? + public fun returnType(): TypeReference = returnTypeOrNull() + ?: throw AnvilCompilationExceptionFunctionalReference( + functionalReference = this, + message = "Unable to get the return type for function $fqName." + ) + + public fun visibility(): Visibility + + public sealed interface Psi : FunctionalReference { + public val function: KtFunction + } + + public sealed interface Descriptor : FunctionalReference { + public val function: FunctionDescriptor + } +} + +@ExperimentalAnvilApi +@Suppress("FunctionName") +public fun AnvilCompilationExceptionFunctionalReference( + functionalReference: FunctionalReference, + message: String, + cause: Throwable? = null +): AnvilCompilationException = when (functionalReference) { + is Psi -> AnvilCompilationException( + element = functionalReference.function, + message = message, + cause = cause + ) + is Descriptor -> AnvilCompilationException( + functionDescriptor = functionalReference.function, + message = message, + cause = cause + ) +} diff --git a/compiler-utils/src/main/java/com/squareup/anvil/compiler/internal/reference/ParameterReference.kt b/compiler-utils/src/main/java/com/squareup/anvil/compiler/internal/reference/ParameterReference.kt index d08498865..1a420f0c7 100644 --- a/compiler-utils/src/main/java/com/squareup/anvil/compiler/internal/reference/ParameterReference.kt +++ b/compiler-utils/src/main/java/com/squareup/anvil/compiler/internal/reference/ParameterReference.kt @@ -13,11 +13,13 @@ import kotlin.LazyThreadSafetyMode.NONE public sealed class ParameterReference : AnnotatedReference { public abstract val name: String - public abstract val declaringFunction: FunctionReference + public abstract val declaringFunction: FunctionalReference public val module: AnvilModuleDescriptor get() = declaringFunction.module protected abstract val type: TypeReference? + protected abstract val declaringClass: ClassReference? + /** * The type can be null for generic type parameters like `T`. In this case try to resolve the * type with [TypeReference.resolveGenericTypeOrNull]. @@ -66,7 +68,7 @@ public sealed class ParameterReference : AnnotatedReference { public class Psi( public val parameter: KtParameter, - override val declaringFunction: FunctionReference.Psi + override val declaringFunction: FunctionalReference.Psi ) : ParameterReference() { override val name: String = parameter.nameAsSafeName.asString() @@ -77,13 +79,19 @@ public sealed class ParameterReference : AnnotatedReference { } override val type: TypeReference.Psi? by lazy(NONE) { - parameter.typeReference?.toTypeReference(declaringFunction.declaringClass, module) + parameter.typeReference?.toTypeReference(declaringClass, module) } + + override val declaringClass: ClassReference.Psi? + get() = when (declaringFunction) { + is TopLevelFunctionReference.Psi -> null + is FunctionReference.Psi -> declaringFunction.declaringClass + } } public class Descriptor( public val parameter: ValueParameterDescriptor, - override val declaringFunction: FunctionReference.Descriptor + override val declaringFunction: FunctionalReference.Descriptor ) : ParameterReference() { override val name: String = parameter.name.asString() @@ -94,21 +102,27 @@ public sealed class ParameterReference : AnnotatedReference { } override val type: TypeReference.Descriptor? by lazy(NONE) { - parameter.type.toTypeReference(declaringFunction.declaringClass, module) + parameter.type.toTypeReference(declaringClass, module) } + + override val declaringClass: ClassReference.Descriptor? + get() = when (declaringFunction) { + is TopLevelFunctionReference.Descriptor -> null + is FunctionReference.Descriptor -> declaringFunction.declaringClass + } } } @ExperimentalAnvilApi public fun KtParameter.toParameterReference( - declaringFunction: FunctionReference.Psi + declaringFunction: FunctionalReference.Psi ): Psi { return Psi(this, declaringFunction) } @ExperimentalAnvilApi public fun ValueParameterDescriptor.toParameterReference( - declaringFunction: FunctionReference.Descriptor + declaringFunction: FunctionalReference.Descriptor ): Descriptor { return Descriptor(this, declaringFunction) } diff --git a/compiler-utils/src/main/java/com/squareup/anvil/compiler/internal/reference/TopLevelFunctionReference.kt b/compiler-utils/src/main/java/com/squareup/anvil/compiler/internal/reference/TopLevelFunctionReference.kt new file mode 100644 index 000000000..1a22783ea --- /dev/null +++ b/compiler-utils/src/main/java/com/squareup/anvil/compiler/internal/reference/TopLevelFunctionReference.kt @@ -0,0 +1,123 @@ +package com.squareup.anvil.compiler.internal.reference + +import com.squareup.anvil.annotations.ExperimentalAnvilApi +import com.squareup.anvil.compiler.internal.reference.TopLevelFunctionReference.Descriptor +import com.squareup.anvil.compiler.internal.reference.TopLevelFunctionReference.Psi +import com.squareup.anvil.compiler.internal.reference.Visibility.INTERNAL +import com.squareup.anvil.compiler.internal.reference.Visibility.PRIVATE +import com.squareup.anvil.compiler.internal.reference.Visibility.PROTECTED +import com.squareup.anvil.compiler.internal.reference.Visibility.PUBLIC +import com.squareup.anvil.compiler.internal.requireFqName +import org.jetbrains.kotlin.descriptors.DescriptorVisibilities +import org.jetbrains.kotlin.descriptors.FunctionDescriptor +import org.jetbrains.kotlin.lexer.KtTokens +import org.jetbrains.kotlin.name.FqName +import org.jetbrains.kotlin.psi.KtFunction +import org.jetbrains.kotlin.psi.psiUtil.visibilityModifierTypeOrDefault +import org.jetbrains.kotlin.resolve.descriptorUtil.fqNameSafe +import kotlin.LazyThreadSafetyMode.NONE + +@ExperimentalAnvilApi +public sealed class TopLevelFunctionReference : AnnotatedReference, FunctionalReference { + + protected abstract val returnType: TypeReference? + + public override fun returnTypeOrNull(): TypeReference? = returnType + + override fun toString(): String = "$fqName()" + + override fun equals(other: Any?): Boolean { + if (this === other) return true + if (other !is TopLevelFunctionReference) return false + + if (fqName != other.fqName) return false + + return true + } + + override fun hashCode(): Int { + return fqName.hashCode() + } + + public class Psi internal constructor( + public override val function: KtFunction, + override val fqName: FqName, + override val module: AnvilModuleDescriptor, + ) : TopLevelFunctionReference(), FunctionalReference.Psi { + + override val annotations: List by lazy(NONE) { + function.annotationEntries.map { + it.toAnnotationReference(declaringClass = null, module) + } + } + + override val returnType: TypeReference.Psi? by lazy(NONE) { + function.typeReference?.toTypeReference(declaringClass = null, module) + } + + override val parameters: List by lazy(NONE) { + function.valueParameters.map { it.toParameterReference(this) } + } + + override fun visibility(): Visibility { + return when (val visibility = function.visibilityModifierTypeOrDefault()) { + KtTokens.PUBLIC_KEYWORD -> PUBLIC + KtTokens.INTERNAL_KEYWORD -> INTERNAL + KtTokens.PROTECTED_KEYWORD -> PROTECTED + KtTokens.PRIVATE_KEYWORD -> PRIVATE + else -> throw AnvilCompilationExceptionFunctionalReference( + functionalReference = this, + message = "Couldn't get visibility $visibility for function $fqName." + ) + } + } + } + + public class Descriptor internal constructor( + public override val function: FunctionDescriptor, + override val fqName: FqName = function.fqNameSafe, + override val module: AnvilModuleDescriptor, + ) : TopLevelFunctionReference(), FunctionalReference.Descriptor { + + override val annotations: List by lazy(NONE) { + function.annotations.map { + it.toAnnotationReference(declaringClass = null, module) + } + } + + override val parameters: List by lazy(NONE) { + function.valueParameters.map { it.toParameterReference(this) } + } + + override val returnType: TypeReference.Descriptor? by lazy(NONE) { + function.returnType?.toTypeReference(declaringClass = null, module) + } + + override fun visibility(): Visibility { + return when (val visibility = function.visibility) { + DescriptorVisibilities.PUBLIC -> PUBLIC + DescriptorVisibilities.INTERNAL -> INTERNAL + DescriptorVisibilities.PROTECTED -> PROTECTED + DescriptorVisibilities.PRIVATE -> PRIVATE + else -> throw AnvilCompilationExceptionFunctionalReference( + functionalReference = this, + message = "Couldn't get visibility $visibility for function $fqName." + ) + } + } + } +} + +@ExperimentalAnvilApi +public fun KtFunction.toTopLevelFunctionReference( + module: AnvilModuleDescriptor, +): Psi { + return Psi(function = this, fqName = requireFqName(), module = module) +} + +@ExperimentalAnvilApi +public fun FunctionDescriptor.toTopLevelFunctionReference( + module: AnvilModuleDescriptor, +): Descriptor { + return Descriptor(function = this, module = module) +} diff --git a/compiler/src/main/java/com/squareup/anvil/compiler/codegen/dagger/AssistedFactoryGenerator.kt b/compiler/src/main/java/com/squareup/anvil/compiler/codegen/dagger/AssistedFactoryGenerator.kt index 50c533fa7..b3a425394 100644 --- a/compiler/src/main/java/com/squareup/anvil/compiler/codegen/dagger/AssistedFactoryGenerator.kt +++ b/compiler/src/main/java/com/squareup/anvil/compiler/codegen/dagger/AssistedFactoryGenerator.kt @@ -15,7 +15,7 @@ import com.squareup.anvil.compiler.codegen.dagger.AssistedFactoryGenerator.Assis import com.squareup.anvil.compiler.internal.asClassName import com.squareup.anvil.compiler.internal.buildFile import com.squareup.anvil.compiler.internal.reference.AnvilCompilationExceptionClassReference -import com.squareup.anvil.compiler.internal.reference.AnvilCompilationExceptionFunctionReference +import com.squareup.anvil.compiler.internal.reference.AnvilCompilationExceptionFunctionalReference import com.squareup.anvil.compiler.internal.reference.ClassReference import com.squareup.anvil.compiler.internal.reference.FunctionReference import com.squareup.anvil.compiler.internal.reference.ParameterReference @@ -73,10 +73,10 @@ internal class AssistedFactoryGenerator : PrivateCodeGenerator() { function.function.resolveGenericReturnType(clazz) } catch (e: AnvilCompilationException) { // Catch the exception and throw the same error that Dagger would. - throw AnvilCompilationExceptionFunctionReference( + throw AnvilCompilationExceptionFunctionalReference( message = "Invalid return type: ${clazz.fqName}. An assisted factory's " + "abstract method must return a type with an @AssistedInject-annotated constructor.", - functionReference = function.function, + functionalReference = function.function, cause = e ) } diff --git a/compiler/src/main/java/com/squareup/anvil/compiler/codegen/dagger/BindsMethodValidator.kt b/compiler/src/main/java/com/squareup/anvil/compiler/codegen/dagger/BindsMethodValidator.kt index 4007ecbc3..9a0eda313 100644 --- a/compiler/src/main/java/com/squareup/anvil/compiler/codegen/dagger/BindsMethodValidator.kt +++ b/compiler/src/main/java/com/squareup/anvil/compiler/codegen/dagger/BindsMethodValidator.kt @@ -6,7 +6,7 @@ import com.squareup.anvil.compiler.api.CodeGenerator import com.squareup.anvil.compiler.codegen.PrivateCodeGenerator import com.squareup.anvil.compiler.daggerBindsFqName import com.squareup.anvil.compiler.daggerModuleFqName -import com.squareup.anvil.compiler.internal.reference.AnvilCompilationExceptionFunctionReference +import com.squareup.anvil.compiler.internal.reference.AnvilCompilationExceptionFunctionalReference import com.squareup.anvil.compiler.internal.reference.ClassReference import com.squareup.anvil.compiler.internal.reference.FunctionReference import com.squareup.anvil.compiler.internal.reference.allSuperTypeClassReferences @@ -53,9 +53,9 @@ internal class BindsMethodValidator : PrivateCodeGenerator() { private fun validateBindsFunction(function: FunctionReference.Psi) { if (!function.isAbstract()) { - throw AnvilCompilationExceptionFunctionReference( + throw AnvilCompilationExceptionFunctionalReference( message = "@Binds methods must be abstract", - functionReference = function + functionalReference = function ) } @@ -63,16 +63,16 @@ internal class BindsMethodValidator : PrivateCodeGenerator() { (function.parameters.size == 1 && !function.function.isExtensionDeclaration()) || (function.parameters.isEmpty() && function.function.isExtensionDeclaration()) if (!hasSingleBindingParameter) { - throw AnvilCompilationExceptionFunctionReference( + throw AnvilCompilationExceptionFunctionalReference( message = "@Binds methods must have exactly one parameter, " + "whose type is assignable to the return type", - functionReference = function + functionalReference = function ) } - function.returnTypeOrNull() ?: throw AnvilCompilationExceptionFunctionReference( + function.returnTypeOrNull() ?: throw AnvilCompilationExceptionFunctionalReference( message = "@Binds methods must return a value (not void)", - functionReference = function + functionalReference = function ) if (!function.parameterMatchesReturnType() && !function.receiverMatchesReturnType()) { @@ -86,11 +86,11 @@ internal class BindsMethodValidator : PrivateCodeGenerator() { } else { "only has the following supertypes: ${paramSuperTypes.drop(1)}" } - throw AnvilCompilationExceptionFunctionReference( + throw AnvilCompilationExceptionFunctionalReference( message = "@Binds methods' parameter type must be assignable to the return type. " + "Expected binding of type $returnType but impl parameter of type " + "${paramSuperTypes.first()} $superTypesMessage", - functionReference = function + functionalReference = function ) } } diff --git a/compiler/src/main/java/com/squareup/anvil/compiler/codegen/dagger/ProvidesMethodFactoryGenerator.kt b/compiler/src/main/java/com/squareup/anvil/compiler/codegen/dagger/ProvidesMethodFactoryGenerator.kt index 34d380264..5896fe565 100644 --- a/compiler/src/main/java/com/squareup/anvil/compiler/codegen/dagger/ProvidesMethodFactoryGenerator.kt +++ b/compiler/src/main/java/com/squareup/anvil/compiler/codegen/dagger/ProvidesMethodFactoryGenerator.kt @@ -12,7 +12,7 @@ import com.squareup.anvil.compiler.daggerProvidesFqName import com.squareup.anvil.compiler.internal.buildFile import com.squareup.anvil.compiler.internal.capitalize import com.squareup.anvil.compiler.internal.reference.AnnotatedReference -import com.squareup.anvil.compiler.internal.reference.AnvilCompilationExceptionFunctionReference +import com.squareup.anvil.compiler.internal.reference.AnvilCompilationExceptionFunctionalReference import com.squareup.anvil.compiler.internal.reference.ClassReference import com.squareup.anvil.compiler.internal.reference.FunctionReference import com.squareup.anvil.compiler.internal.reference.PropertyReference @@ -301,9 +301,9 @@ internal class ProvidesMethodFactoryGenerator : PrivateCodeGenerator() { clazz: ClassReference.Psi, function: FunctionReference.Psi ) { - fun fail(): Nothing = throw AnvilCompilationExceptionFunctionReference( + fun fail(): Nothing = throw AnvilCompilationExceptionFunctionalReference( message = "@Provides methods cannot be abstract", - functionReference = function + functionalReference = function ) // If the function is abstract, then it's an error. @@ -358,10 +358,10 @@ internal class ProvidesMethodFactoryGenerator : PrivateCodeGenerator() { function?.parameters?.mapToConstructorParameters() ?: emptyList() val type: TypeReference = function?.let { - it.returnTypeOrNull() ?: throw AnvilCompilationExceptionFunctionReference( + it.returnTypeOrNull() ?: throw AnvilCompilationExceptionFunctionalReference( message = "Dagger provider methods must specify the return type explicitly when using " + "Anvil. The return type cannot be inferred implicitly.", - functionReference = it + functionalReference = it ) } ?: property!!.type() val annotationReference: AnnotatedReference = function ?: property!! diff --git a/compiler/src/main/java/com/squareup/anvil/compiler/codegen/reference/RealAnvilModuleDescriptor.kt b/compiler/src/main/java/com/squareup/anvil/compiler/codegen/reference/RealAnvilModuleDescriptor.kt index 12046df8a..c73bccc65 100644 --- a/compiler/src/main/java/com/squareup/anvil/compiler/codegen/reference/RealAnvilModuleDescriptor.kt +++ b/compiler/src/main/java/com/squareup/anvil/compiler/codegen/reference/RealAnvilModuleDescriptor.kt @@ -9,6 +9,8 @@ import com.squareup.anvil.compiler.internal.reference.AnvilModuleDescriptor import com.squareup.anvil.compiler.internal.reference.ClassReference import com.squareup.anvil.compiler.internal.reference.ClassReference.Descriptor import com.squareup.anvil.compiler.internal.reference.ClassReference.Psi +import com.squareup.anvil.compiler.internal.reference.TopLevelFunctionReference +import com.squareup.anvil.compiler.internal.reference.toTopLevelFunctionReference import com.squareup.anvil.compiler.internal.requireFqName import org.jetbrains.kotlin.descriptors.ClassDescriptor import org.jetbrains.kotlin.descriptors.ModuleDescriptor @@ -20,6 +22,7 @@ import org.jetbrains.kotlin.name.ClassId import org.jetbrains.kotlin.name.FqName import org.jetbrains.kotlin.psi.KtClassOrObject import org.jetbrains.kotlin.psi.KtFile +import org.jetbrains.kotlin.psi.KtFunction import org.jetbrains.kotlin.psi.psiUtil.parentsWithSelf import org.jetbrains.kotlin.resolve.descriptorUtil.classId import org.jetbrains.kotlin.resolve.descriptorUtil.fqNameSafe @@ -32,6 +35,9 @@ class RealAnvilModuleDescriptor private constructor( private val allPsiClassReferences: Sequence get() = ktFileToClassReferenceMap.values.asSequence().flatten() + private val ktFileToTopLevelFunctionReferenceMap = + mutableMapOf>() + private val resolveDescriptorCache = mutableMapOf() private val resolveClassIdCache = mutableMapOf() private val classReferenceCache = mutableMapOf() @@ -66,6 +72,12 @@ class RealAnvilModuleDescriptor private constructor( } } + override fun getTopLevelFunctionReferences(ktFile: KtFile): List { + return ktFileToTopLevelFunctionReferenceMap.getOrPut(ktFile.identifier) { + ktFile.topLevelFunctions().map { it.toTopLevelFunctionReference(this) } + } + } + override fun resolveClassIdOrNull(classId: ClassId): FqName? = resolveClassIdCache.getOrPut(classId) { val fqName = classId.asSingleFqName() @@ -166,6 +178,10 @@ private fun KtFile.classesAndInnerClasses(): List { }.flatten().toList() } +private fun KtFile.topLevelFunctions(): List { + return findChildrenByClass(KtFunction::class.java).toList() +} + private fun KtClassOrObject.toClassId(): ClassId { val className = parentsWithSelf.filterIsInstance() .toList() diff --git a/compiler/src/test/java/com/squareup/anvil/compiler/internal/reference/TopLevelFunctionReferenceTest.kt b/compiler/src/test/java/com/squareup/anvil/compiler/internal/reference/TopLevelFunctionReferenceTest.kt new file mode 100644 index 000000000..60e2c4aee --- /dev/null +++ b/compiler/src/test/java/com/squareup/anvil/compiler/internal/reference/TopLevelFunctionReferenceTest.kt @@ -0,0 +1,161 @@ +package com.squareup.anvil.compiler.internal.reference + +import com.google.common.truth.Truth.assertThat +import com.squareup.anvil.compiler.api.AnvilContext +import com.squareup.anvil.compiler.api.CodeGenerator +import com.squareup.anvil.compiler.api.GeneratedFile +import com.squareup.anvil.compiler.compile +import com.squareup.anvil.compiler.internal.reference.Visibility.INTERNAL +import com.squareup.anvil.compiler.internal.reference.Visibility.PRIVATE +import com.squareup.anvil.compiler.internal.reference.Visibility.PUBLIC +import com.tschuchort.compiletesting.KotlinCompilation.ExitCode.OK +import org.jetbrains.kotlin.descriptors.FunctionDescriptor +import org.jetbrains.kotlin.descriptors.ModuleDescriptor +import org.jetbrains.kotlin.name.FqName +import org.jetbrains.kotlin.psi.KtFile +import org.jetbrains.kotlin.psi.KtFunction +import org.jetbrains.kotlin.resolve.scopes.DescriptorKindFilter +import org.jetbrains.kotlin.resolve.source.getPsi +import org.junit.Test +import java.io.File + +class TopLevelFunctionReferenceTest { + + @Test fun `top level functions are parsed for PSI and Descriptor APIs correctly`() { + compile( + """ + package com.squareup.test + + private fun fun1() = Unit + + fun fun2(string: String): String = string + + @PublishedApi + internal fun fun3(): Int? = null + + fun fun4(): T = throw NotImplementedError() + + fun fun5(param1: T, param2: () -> T): T? = null + """, + allWarningsAsErrors = false, + codeGenerators = listOf( + object : CodeGenerator { + override fun isApplicable(context: AnvilContext): Boolean = true + + override fun generateCode( + codeGenDir: File, + module: ModuleDescriptor, + projectFiles: Collection + ): Collection { + val functions = projectFiles.topLevelFunctionReferences(module).toList() + assertThat(functions).hasSize(5) + + functions + .flatMap { listOf(it.toPsiReference(), it.toDescriptorReference()) } + .forEach { ref -> + when (ref.name) { + "fun1" -> { + assertThat(ref.fqName).isEqualTo(FqName("com.squareup.test.fun1")) + assertThat(ref.parameters).isEmpty() + assertThat(ref.annotations).isEmpty() + assertThat(ref.visibility()).isEqualTo(PRIVATE) + when (ref) { + is TopLevelFunctionReference.Psi -> + assertThat(ref.returnTypeOrNull()).isNull() + + is TopLevelFunctionReference.Descriptor -> + assertThat(ref.returnType().asClassReference().fqName.asString()) + .isEqualTo("kotlin.Unit") + } + } + + "fun2" -> { + assertThat(ref.fqName).isEqualTo(FqName("com.squareup.test.fun2")) + assertThat(ref.parameters.single().type().asClassReference().fqName.asString()) + .isEqualTo("kotlin.String") + assertThat(ref.annotations).isEmpty() + assertThat(ref.visibility()).isEqualTo(PUBLIC) + assertThat(ref.returnType().asClassReference().fqName.asString()) + .isEqualTo("kotlin.String") + } + + "fun3" -> { + assertThat(ref.fqName).isEqualTo(FqName("com.squareup.test.fun3")) + assertThat(ref.parameters).isEmpty() + assertThat(ref.annotations.single().fqName.asString()) + .isEqualTo("kotlin.PublishedApi") + assertThat(ref.visibility()).isEqualTo(INTERNAL) + assertThat(ref.returnType().asClassReference().fqName.asString()) + .isEqualTo("kotlin.Int") + assertThat(ref.returnType().isNullable()).isTrue() + } + + "fun4" -> { + assertThat(ref.fqName).isEqualTo(FqName("com.squareup.test.fun4")) + assertThat(ref.parameters).isEmpty() + assertThat(ref.annotations).isEmpty() + assertThat(ref.visibility()).isEqualTo(PUBLIC) + assertThat(ref.returnType().isGenericType()).isTrue() + } + + "fun5" -> { + assertThat(ref.fqName).isEqualTo(FqName("com.squareup.test.fun5")) + assertThat(ref.parameters).hasSize(2) + assertThat(ref.parameters[0].type().isGenericType()).isTrue() + assertThat(ref.parameters[1].type().isFunctionType()).isTrue() + assertThat(ref.annotations).isEmpty() + assertThat(ref.visibility()).isEqualTo(PUBLIC) + assertThat(ref.returnType().isGenericType()).isTrue() + assertThat(ref.returnType().isNullable()).isTrue() + } + + else -> throw NotImplementedError() + } + } + + return emptyList() + } + } + ) + ) { + assertThat(exitCode).isEqualTo(OK) + } + } +} + +fun TopLevelFunctionReference.toDescriptorReference(): TopLevelFunctionReference.Descriptor { + return when (this) { + is TopLevelFunctionReference.Descriptor -> this + is TopLevelFunctionReference.Psi -> { + // Force using the descriptor. + module.getPackage(fqName.parent()).memberScope + .getContributedDescriptors(DescriptorKindFilter.FUNCTIONS) + .filterIsInstance() + .single { it.name.asString() == name } + .toTopLevelFunctionReference(module) + .also { descriptorReference -> + assertThat(descriptorReference) + .isInstanceOf(TopLevelFunctionReference.Descriptor::class.java) + + assertThat(this).isEqualTo(descriptorReference) + assertThat(this.fqName).isEqualTo(descriptorReference.fqName) + } + } + } +} + +fun TopLevelFunctionReference.toPsiReference(): TopLevelFunctionReference.Psi { + return when (this) { + is TopLevelFunctionReference.Psi -> this + is TopLevelFunctionReference.Descriptor -> { + // Force using Psi. + (function.source.getPsi() as KtFunction).toTopLevelFunctionReference(module) + .also { psiReference -> + assertThat(psiReference).isInstanceOf(TopLevelFunctionReference.Psi::class.java) + + assertThat(this).isEqualTo(psiReference) + assertThat(this.fqName).isEqualTo(psiReference.fqName) + } + } + } +} diff --git a/integration-tests/code-generator-tests/src/main/java/com/squareup/anvil/test/Trigger.kt b/integration-tests/code-generator-tests/src/main/java/com/squareup/anvil/test/Trigger.kt index 7171617f9..bfd0e314b 100644 --- a/integration-tests/code-generator-tests/src/main/java/com/squareup/anvil/test/Trigger.kt +++ b/integration-tests/code-generator-tests/src/main/java/com/squareup/anvil/test/Trigger.kt @@ -4,3 +4,6 @@ annotation class Trigger @Trigger class AnyClass + +@Trigger +fun abc() = Unit diff --git a/integration-tests/code-generator-tests/src/test/java/com/squareup/anvil/test/GeneratedCodeTest.kt b/integration-tests/code-generator-tests/src/test/java/com/squareup/anvil/test/GeneratedCodeTest.kt index 6ba707d23..bb7996176 100644 --- a/integration-tests/code-generator-tests/src/test/java/com/squareup/anvil/test/GeneratedCodeTest.kt +++ b/integration-tests/code-generator-tests/src/test/java/com/squareup/anvil/test/GeneratedCodeTest.kt @@ -78,6 +78,12 @@ class GeneratedCodeTest { assertThat(int).isEqualTo(7) } + @Test fun `the function class is generated`() { + val generatedClass = + Class.forName("generated.test.com.squareup.anvil.test.GeneratedFunctionClass") + assertThat(generatedClass).isNotNull() + } + @MergeComponent(Unit::class) interface AppComponent { fun otherClass(): OtherClass diff --git a/integration-tests/code-generator/src/main/java/com/squareup/anvil/test/TestCodeGenerator.kt b/integration-tests/code-generator/src/main/java/com/squareup/anvil/test/TestCodeGenerator.kt index e57074c7d..28d831095 100644 --- a/integration-tests/code-generator/src/main/java/com/squareup/anvil/test/TestCodeGenerator.kt +++ b/integration-tests/code-generator/src/main/java/com/squareup/anvil/test/TestCodeGenerator.kt @@ -6,6 +6,7 @@ import com.squareup.anvil.compiler.api.CodeGenerator import com.squareup.anvil.compiler.api.GeneratedFile import com.squareup.anvil.compiler.api.createGeneratedFile import com.squareup.anvil.compiler.internal.reference.classAndInnerClassReferences +import com.squareup.anvil.compiler.internal.reference.topLevelFunctionReferences import com.squareup.anvil.compiler.internal.safePackageString import org.intellij.lang.annotations.Language import org.jetbrains.kotlin.descriptors.ModuleDescriptor @@ -173,6 +174,31 @@ class TestCodeGenerator : CodeGenerator { ), ) } + .plus( + projectFiles + .topLevelFunctionReferences(module) + .filter { it.isAnnotatedWith(FqName("com.squareup.anvil.test.Trigger")) } + .flatMap { + val generatedPackage = "generated.test" + it.fqName.parent() + .safePackageString(dotPrefix = true, dotSuffix = false) + + @Language("kotlin") + val generatedClass = """ + package $generatedPackage + + class GeneratedFunctionClass + """.trimIndent() + + listOf( + createGeneratedFile( + codeGenDir = codeGenDir, + packageName = generatedPackage, + fileName = "GeneratedFunctionClass", + content = generatedClass + ), + ) + } + ) .toList() } }