Skip to content

Commit

Permalink
Allow to query top level functions as references for code generators.
Browse files Browse the repository at this point in the history
Fixes #644
  • Loading branch information
vRallev committed Jan 11, 2023
1 parent 0315374 commit 432b6d4
Show file tree
Hide file tree
Showing 14 changed files with 451 additions and 75 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@

### Custom Code Generator

- Add ability to query top-level functions. This allows you write code generators for top-level functions, see #644.

## [2.4.3] - 2022-12-16

### Added
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ public interface AnvilModuleDescriptor : ModuleDescriptor {

public fun getClassAndInnerClassReferences(ktFile: KtFile): List<Psi>

public fun getTopLevelFunctionReferences(ktFile: KtFile): List<TopLevelFunctionReference.Psi>

public fun getClassReference(clazz: KtClassOrObject): Psi

public fun getClassReference(descriptor: ClassDescriptor): Descriptor
Expand Down Expand Up @@ -58,3 +60,12 @@ public fun Collection<KtFile>.classAndInnerClassReferences(
module.asAnvilModuleDescriptor().getClassAndInnerClassReferences(it)
}
}

@ExperimentalAnvilApi
public fun Collection<KtFile>.topLevelFunctionReferences(
module: ModuleDescriptor
): Sequence<TopLevelFunctionReference.Psi> {
return asSequence().flatMap {
module.asAnvilModuleDescriptor().getTopLevelFunctionReferences(it)
}
}
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package com.squareup.anvil.compiler.internal.reference

import com.squareup.anvil.annotations.ExperimentalAnvilApi
import com.squareup.anvil.compiler.api.AnvilCompilationException
import com.squareup.anvil.compiler.internal.reference.FunctionReference.Descriptor
import com.squareup.anvil.compiler.internal.reference.FunctionReference.Psi
import com.squareup.anvil.compiler.internal.reference.Visibility.INTERNAL
Expand Down Expand Up @@ -30,29 +29,18 @@ import kotlin.LazyThreadSafetyMode.NONE
* [FunctionDescriptor] references, to streamline parsing.
*/
@ExperimentalAnvilApi
public sealed class FunctionReference : AnnotatedReference {
public sealed class FunctionReference : AnnotatedReference, FunctionalReference {

public abstract val fqName: FqName
public abstract val declaringClass: ClassReference
public val name: String get() = fqName.shortName().asString()

public val module: AnvilModuleDescriptor get() = declaringClass.module

public abstract val parameters: List<ParameterReference>
public override val module: AnvilModuleDescriptor get() = declaringClass.module

protected abstract val returnType: TypeReference?

public fun returnTypeOrNull(): TypeReference? = returnType
public fun returnType(): TypeReference = returnType
?: throw AnvilCompilationExceptionFunctionReference(
functionReference = this,
message = "Unable to get the return type for function $fqName."
)
public override fun returnTypeOrNull(): TypeReference? = returnType

public abstract fun isAbstract(): Boolean
public abstract fun isConstructor(): Boolean
public abstract fun visibility(): Visibility

public fun resolveGenericReturnTypeOrNull(
implementingClass: ClassReference
): ClassReference? {
Expand All @@ -63,8 +51,8 @@ public sealed class FunctionReference : AnnotatedReference {

public fun resolveGenericReturnType(implementingClass: ClassReference): ClassReference =
resolveGenericReturnTypeOrNull(implementingClass)
?: throw AnvilCompilationExceptionFunctionReference(
functionReference = this,
?: throw AnvilCompilationExceptionFunctionalReference(
functionalReference = this,
message = "Unable to resolve return type for function $fqName with the implementing " +
"class ${implementingClass.fqName}."
)
Expand All @@ -73,7 +61,7 @@ public sealed class FunctionReference : AnnotatedReference {

override fun equals(other: Any?): Boolean {
if (this === other) return true
if (other !is ClassReference) return false
if (other !is FunctionReference) return false

if (fqName != other.fqName) return false

Expand All @@ -85,10 +73,10 @@ public sealed class FunctionReference : AnnotatedReference {
}

public class Psi internal constructor(
public val function: KtFunction,
public override val function: KtFunction,
override val declaringClass: ClassReference.Psi,
override val fqName: FqName
) : FunctionReference() {
) : FunctionReference(), FunctionalReference.Psi {

override val annotations: List<AnnotationReference.Psi> by lazy(NONE) {
function.annotationEntries.map {
Expand Down Expand Up @@ -125,10 +113,10 @@ public sealed class FunctionReference : AnnotatedReference {
}

public class Descriptor internal constructor(
public val function: FunctionDescriptor,
public override val function: FunctionDescriptor,
override val declaringClass: ClassReference.Descriptor,
override val fqName: FqName = function.fqNameSafe
) : FunctionReference() {
) : FunctionReference(), FunctionalReference.Descriptor {

override val annotations: List<AnnotationReference.Descriptor> by lazy(NONE) {
function.annotations.map {
Expand All @@ -144,16 +132,6 @@ public sealed class FunctionReference : AnnotatedReference {
function.returnType?.toTypeReference(declaringClass, module)
}

internal val overriddenFunctions by lazy(NONE) {
generateSequence(
function.overriddenDescriptors
) { overriddenFunctions ->
overriddenFunctions
.flatMap { it.overriddenDescriptors }
.takeIf { it.isNotEmpty() }
}.flatten().toList()
}

override fun isAbstract(): Boolean = function.modality == ABSTRACT

override fun isConstructor(): Boolean = function is ClassConstructorDescriptor
Expand Down Expand Up @@ -192,22 +170,3 @@ public fun FunctionDescriptor.toFunctionReference(
): Descriptor {
return Descriptor(this, declaringClass)
}

@ExperimentalAnvilApi
@Suppress("FunctionName")
public fun AnvilCompilationExceptionFunctionReference(
functionReference: FunctionReference,
message: String,
cause: Throwable? = null
): AnvilCompilationException = when (functionReference) {
is Psi -> AnvilCompilationException(
element = functionReference.function,
message = message,
cause = cause
)
is Descriptor -> AnvilCompilationException(
functionDescriptor = functionReference.function,
message = message,
cause = cause
)
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
package com.squareup.anvil.compiler.internal.reference

import com.squareup.anvil.annotations.ExperimentalAnvilApi
import com.squareup.anvil.compiler.api.AnvilCompilationException
import com.squareup.anvil.compiler.internal.reference.FunctionalReference.Descriptor
import com.squareup.anvil.compiler.internal.reference.FunctionalReference.Psi
import org.jetbrains.kotlin.descriptors.FunctionDescriptor
import org.jetbrains.kotlin.name.FqName
import org.jetbrains.kotlin.psi.KtFunction

@ExperimentalAnvilApi
public sealed interface FunctionalReference {
public val fqName: FqName
public val name: String get() = fqName.shortName().asString()

public val module: AnvilModuleDescriptor

public val parameters: List<ParameterReference>

public fun returnTypeOrNull(): TypeReference?
public fun returnType(): TypeReference = returnTypeOrNull()
?: throw AnvilCompilationExceptionFunctionalReference(
functionalReference = this,
message = "Unable to get the return type for function $fqName."
)

public fun visibility(): Visibility

public sealed interface Psi : FunctionalReference {
public val function: KtFunction
}

public sealed interface Descriptor : FunctionalReference {
public val function: FunctionDescriptor
}
}

@ExperimentalAnvilApi
@Suppress("FunctionName")
public fun AnvilCompilationExceptionFunctionalReference(
functionalReference: FunctionalReference,
message: String,
cause: Throwable? = null
): AnvilCompilationException = when (functionalReference) {
is Psi -> AnvilCompilationException(
element = functionalReference.function,
message = message,
cause = cause
)
is Descriptor -> AnvilCompilationException(
functionDescriptor = functionalReference.function,
message = message,
cause = cause
)
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,13 @@ import kotlin.LazyThreadSafetyMode.NONE
public sealed class ParameterReference : AnnotatedReference {

public abstract val name: String
public abstract val declaringFunction: FunctionReference
public abstract val declaringFunction: FunctionalReference
public val module: AnvilModuleDescriptor get() = declaringFunction.module

protected abstract val type: TypeReference?

protected abstract val declaringClass: ClassReference?

/**
* The type can be null for generic type parameters like `T`. In this case try to resolve the
* type with [TypeReference.resolveGenericTypeOrNull].
Expand Down Expand Up @@ -66,7 +68,7 @@ public sealed class ParameterReference : AnnotatedReference {

public class Psi(
public val parameter: KtParameter,
override val declaringFunction: FunctionReference.Psi
override val declaringFunction: FunctionalReference.Psi
) : ParameterReference() {
override val name: String = parameter.nameAsSafeName.asString()

Expand All @@ -77,13 +79,19 @@ public sealed class ParameterReference : AnnotatedReference {
}

override val type: TypeReference.Psi? by lazy(NONE) {
parameter.typeReference?.toTypeReference(declaringFunction.declaringClass, module)
parameter.typeReference?.toTypeReference(declaringClass, module)
}

override val declaringClass: ClassReference.Psi?
get() = when (declaringFunction) {
is TopLevelFunctionReference.Psi -> null
is FunctionReference.Psi -> declaringFunction.declaringClass
}
}

public class Descriptor(
public val parameter: ValueParameterDescriptor,
override val declaringFunction: FunctionReference.Descriptor
override val declaringFunction: FunctionalReference.Descriptor
) : ParameterReference() {
override val name: String = parameter.name.asString()

Expand All @@ -94,21 +102,27 @@ public sealed class ParameterReference : AnnotatedReference {
}

override val type: TypeReference.Descriptor? by lazy(NONE) {
parameter.type.toTypeReference(declaringFunction.declaringClass, module)
parameter.type.toTypeReference(declaringClass, module)
}

override val declaringClass: ClassReference.Descriptor?
get() = when (declaringFunction) {
is TopLevelFunctionReference.Descriptor -> null
is FunctionReference.Descriptor -> declaringFunction.declaringClass
}
}
}

@ExperimentalAnvilApi
public fun KtParameter.toParameterReference(
declaringFunction: FunctionReference.Psi
declaringFunction: FunctionalReference.Psi
): Psi {
return Psi(this, declaringFunction)
}

@ExperimentalAnvilApi
public fun ValueParameterDescriptor.toParameterReference(
declaringFunction: FunctionReference.Descriptor
declaringFunction: FunctionalReference.Descriptor
): Descriptor {
return Descriptor(this, declaringFunction)
}
Expand Down
Loading

0 comments on commit 432b6d4

Please sign in to comment.