Skip to content

Commit

Permalink
SIP-61 - fixed unpickling errors - invisible select and incorrect spans
Browse files Browse the repository at this point in the history
  • Loading branch information
bishabosha committed Oct 2, 2024
1 parent 818bd51 commit 5b253a0
Show file tree
Hide file tree
Showing 12 changed files with 203 additions and 85 deletions.
2 changes: 2 additions & 0 deletions compiler/src/dotty/tools/dotc/core/Definitions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1040,6 +1040,8 @@ class Definitions {
@tu lazy val NowarnAnnot: ClassSymbol = requiredClass("scala.annotation.nowarn")
@tu lazy val UnusedAnnot: ClassSymbol = requiredClass("scala.annotation.unused")
@tu lazy val UnrollAnnot: ClassSymbol = requiredClass("scala.annotation.unroll")
@tu lazy val AbstractUnrollAnnot: ClassSymbol = requiredClass("scala.annotation.internal.AbstractUnroll")
@tu lazy val UnrollForwarderAnnot: ClassSymbol = requiredClass("scala.annotation.internal.UnrollForwarder")
@tu lazy val TransparentTraitAnnot: ClassSymbol = requiredClass("scala.annotation.transparentTrait")
@tu lazy val NativeAnnot: ClassSymbol = requiredClass("scala.native")
@tu lazy val RepeatedAnnot: ClassSymbol = requiredClass("scala.annotation.internal.Repeated")
Expand Down
15 changes: 10 additions & 5 deletions compiler/src/dotty/tools/dotc/core/tasty/TreePickler.scala
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import collection.mutable
import reporting.{Profile, NoProfile}
import dotty.tools.tasty.TastyFormat.ASTsSection
import quoted.QuotePatterns
import dotty.tools.dotc.config.Feature

object TreePickler:
class StackSizeExceeded(val mdef: tpd.MemberDef) extends Exception
Expand Down Expand Up @@ -474,26 +475,30 @@ class TreePickler(pickler: TastyPickler, attributes: Attributes) {
case _ =>
if passesConditionForErroringBestEffortCode(tree.hasType) then
// #19951 The signature of a constructor of a Java annotation is irrelevant
val sym = tree.symbol
val sig =
if name == nme.CONSTRUCTOR && tree.symbol.exists && tree.symbol.owner.is(JavaAnnotation) then Signature.NotAMethod
if name == nme.CONSTRUCTOR && sym.exists && sym.owner.is(JavaAnnotation) then Signature.NotAMethod
else tree.tpe.signature
var ename = tree.symbol.targetName
var ename = sym.targetName
val selectFromQualifier =
name.isTypeName
|| qual.isInstanceOf[Hole] // holes have no symbol
|| sig == Signature.NotAMethod // no overload resolution necessary
|| !tree.denot.symbol.exists // polymorphic function type
|| !sym.exists // polymorphic function type
|| tree.denot.asSingleDenotation.isRefinedMethod // refined methods have no defining class symbol
if selectFromQualifier then
writeByte(if name.isTypeName then SELECTtpt else SELECT)
pickleNameAndSig(name, sig, ename)
pickleTree(qual)
else if sym.is(Invisible) && qual.isInstanceOf[This] && sym.hasAnnotation(defn.UnrollForwarderAnnot) then
writeByte(TERMREFdirect)
pickleSymRef(sym) // SIP-61 HACK: resolution from Signature filters out Invisible symbols
else // select from owner
writeByte(SELECTin)
withLength {
pickleNameAndSig(name, tree.symbol.signature, ename)
pickleNameAndSig(name, sym.signature, ename)
pickleTree(qual)
pickleType(tree.symbol.owner.typeRef)
pickleType(sym.owner.typeRef)
}
else
writeByte(if name.isTypeName then SELECTtpt else SELECT)
Expand Down
7 changes: 6 additions & 1 deletion compiler/src/dotty/tools/dotc/core/tasty/TreeUnpickler.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1259,7 +1259,12 @@ class TreeUnpickler(reader: TastyReader,
goto(start)
readType() match {
case path: TypeRef => TypeTree(path)
case path: TermRef => ref(path)
case path: TermRef =>
val sym = path.symbol
if sym.is(Invisible) && sym.hasAnnotation(defn.UnrollForwarderAnnot) then
This(sym.owner.asClass).select(sym)
else
ref(path)
case path: ThisType => untpd.This(untpd.EmptyTypeIdent).withType(path)
case path: ConstantType => Literal(path.value)
case path: ErrorType if isBestEffortTasty => TypeTree(path)
Expand Down
3 changes: 3 additions & 0 deletions compiler/src/dotty/tools/dotc/transform/Pickler.scala
Original file line number Diff line number Diff line change
Expand Up @@ -457,6 +457,9 @@ class Pickler extends Phase {
case None =>
()

if ctx.settings.YtestPicklerCheck.value then
sys.error(printedTasty(cls))

inContext(printerContext(testJava)(using rootCtx.fresh.setCompilationUnit(freshUnit))):
testSame(i"$unpickled%\n%", beforePickling(cls), cls)

Expand Down
186 changes: 115 additions & 71 deletions compiler/src/dotty/tools/dotc/transform/UnrollDefinitions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,17 @@ class UnrollDefinitions extends MacroTransform, IdentityDenotTransformer {

import tpd.*

private var _unrolledDefs: util.HashMap[Symbol, ComputedIndicies] | Null = null
private def initializeUnrolledDefs(): util.HashMap[Symbol, ComputedIndicies] =
val local = _unrolledDefs
if local == null then
val map = new util.HashMap[Symbol, ComputedIndicies]
_unrolledDefs = map
map
else
local.clear()
local

override def phaseName: String = UnrollDefinitions.name

override def description: String = UnrollDefinitions.description
Expand All @@ -38,21 +49,40 @@ class UnrollDefinitions extends MacroTransform, IdentityDenotTransformer {

override def run(using Context): Unit =
if ctx.compilationUnit.hasUnrollDefs then
super.run
super.run // create and run the transformer on the current compilation unit

def newTransformer(using Context): Transformer =
UnrollingTransformer(ctx.compilationUnit.unrolledClasses.nn)

type ComputedIndicies = Seq[(Int, List[Int])]
type ComputeIndicies = Context ?=> Symbol => ComputedIndicies

private class UnrollingTransformer(classes: Set[Symbol]) extends Transformer {
private val unrolledDefs = initializeUnrolledDefs()

def computeIndices(annotated: Symbol)(using Context): ComputedIndicies =
unrolledDefs.getOrElseUpdate(annotated, {
annotated
.paramSymss
.zipWithIndex
.flatMap { (paramClause, paramClauseIndex) =>
val annotationIndices = findUnrollAnnotations(paramClause)
if (annotationIndices.isEmpty) None
else Some((paramClauseIndex, annotationIndices))
}
})
end computeIndices

override def transform(tree: tpd.Tree)(using Context): tpd.Tree = tree match
case tree @ TypeDef(_, impl: Template) if classes(tree.symbol) =>
super.transform(cpy.TypeDef(tree)(rhs = unrollTemplate(impl)))
super.transform(cpy.TypeDef(tree)(rhs = unrollTemplate(impl, computeIndices)))
case tree =>
super.transform(tree)
}

def copyParamSym(sym: Symbol, parent: Symbol)(using Context): (Symbol, Symbol) =
sym -> sym.copy(owner = parent, flags = (sym.flags &~ HasDefault))
val copied = sym.copy(owner = parent, flags = (sym.flags &~ HasDefault), coord = sym.coord)
sym -> copied

def symLocation(sym: Symbol)(using Context) = {
val lineDesc =
Expand All @@ -79,46 +109,62 @@ class UnrollDefinitions extends MacroTransform, IdentityDenotTransformer {
nextParamIndex: Int,
nextSymbol: Symbol,
annotatedParamListIndex: Int,
isCaseApply: Boolean)(using Context) = {
isCaseApply: Boolean,
inferOverride: Boolean)(using Context) = {

def initNewForwarder()(using Context): (TermSymbol, List[List[Symbol]]) = {
val forwarderDefSymbol0 = Symbols.newSymbol(
defdef.symbol.owner,
defdef.name,
(defdef.symbol.flags &~
HasDefaultParams &~
(if nextParamIndex == -1 then EmptyFlags else Deferred)) |
Invisible | Synthetic |
(if inferOverride then Override else EmptyFlags),
NoType, // fill in later
coord = nextSymbol.span.shift(1) // shift by 1 to avoid "secondary constructor must call preceding" error
).entered

// we need this such that when unpickling a TERMREFdirect, if we see this annotation,
// we restore the tree to a Select
forwarderDefSymbol0.addAnnotation(defn.UnrollForwarderAnnot)

val newParamSymMappings = extractParamSymss(copyParamSym(_, forwarderDefSymbol0))
val (oldParams, newParams) = newParamSymMappings.flatten.unzip

val newParamSymLists0 =
newParamSymMappings.map: pairss =>
pairss.map: (oldSym, newSym) =>
newSym.info = oldSym.info.substSym(oldParams, newParams)
newSym

val newResType = defdef.tpt.tpe.substSym(oldParams, newParams)
forwarderDefSymbol0.info = NamerOps.methodType(newParamSymLists0, newResType)
forwarderDefSymbol0.setParamss(newParamSymLists0)
forwarderDefSymbol0 -> newParamSymLists0
}

def extractParamSymss(parent: Symbol)(using Context): List[List[(Symbol, Symbol)]] =
def extractParamSymss[T](onSymbol: Symbol => T): List[List[T]] =
defdef.paramss.zipWithIndex.map{ case (ps, i) =>
if (i == annotatedParamListIndex) ps.take(paramIndex).map(p => copyParamSym(p.symbol, parent))
else ps.map(p => copyParamSym(p.symbol, parent))
if (i == annotatedParamListIndex) ps.take(paramIndex).map(p => onSymbol(p.symbol))
else ps.map(p => onSymbol(p.symbol))
}

val isOverride = {
val candidate = defdef.symbol.nextOverriddenSymbol
candidate.exists && !candidate.is(Deferred)
}

val forwarderDefSymbol = Symbols.newSymbol(
defdef.symbol.owner,
defdef.name,
defdef.symbol.flags &~
HasDefaultParams &~
(if (nextParamIndex == -1) Flags.EmptyFlags else Deferred) |
Invisible | (if (isOverride) Override else EmptyFlags),
NoType, // fill in later
).entered

if nextParamIndex == -1 then
defdef.symbol.owner.asClass.info.decls.openForMutations.unlink(defdef.symbol)
else if isOverride then
defdef.symbol.flags_=(defdef.symbol.flags | Override)

val newParamSymMappings = extractParamSymss(forwarderDefSymbol)
val (oldParams, newParams) = newParamSymMappings.flatten.unzip
val paramCount = defdef.symbol.paramSymss(annotatedParamListIndex).size
val isDeferredInitial = paramCount == paramIndex && defdef.symbol.is(Deferred)

val newParamSymLists =
newParamSymMappings.map: pairss =>
pairss.map: (oldSym, newSym) =>
newSym.info = oldSym.info.substSym(oldParams, newParams)
newSym
val (forwarderDefSymbol, newParamSymLists) =
if isDeferredInitial then
val existing = defdef.symbol.asTerm
existing.addAnnotation(defn.AbstractUnrollAnnot) // mark as previously abstract
existing.flags = (existing.flags &~ Deferred) // going to implement its rhs
existing -> extractParamSymss(identity)
else
initNewForwarder()

val newResType = defdef.tpt.tpe.substSym(oldParams, newParams)
forwarderDefSymbol.info = NamerOps.methodType(newParamSymLists, newResType)
forwarderDefSymbol.setParamss(newParamSymLists)
if inferOverride then
// in this case we will not replace the source method, but we will add the override flag
defdef.symbol.flags_=(defdef.symbol.flags | Override)

def forwarderRhs(): tpd.Tree = {
val defaultOffset = defdef.paramss
Expand Down Expand Up @@ -185,9 +231,9 @@ class UnrollDefinitions extends MacroTransform, IdentityDenotTransformer {

val forwarderDef =
tpd.DefDef(forwarderDefSymbol,
rhs = if (nextParamIndex == -1) EmptyTree else forwarderRhs())
rhs = if nextParamIndex == -1 then EmptyTree else forwarderRhs())

forwarderDef
forwarderDef.withSpan(if isDeferredInitial then defdef.span else nextSymbol.span.shift(1))
}

def generateFromProduct(startParamIndices: List[Int], paramCount: Int, defdef: DefDef)(using Context) = {
Expand Down Expand Up @@ -222,10 +268,25 @@ class UnrollDefinitions extends MacroTransform, IdentityDenotTransformer {
).setDefTree
}

def generateSyntheticDefs(tree: Tree)(using Context): (Option[(Symbol, Seq[Symbol])], Seq[(Symbol, Tree)]) = tree match {
def generateSyntheticDefs(tree: Tree, compute: ComputeIndicies)(using Context): (Option[Symbol], Seq[(Symbol, Tree)]) = tree match {
case defdef: DefDef if defdef.paramss.nonEmpty =>
import dotty.tools.dotc.core.NameOps.isConstructorName

// infer an override when we are implementing a method that matches the signature and has unroll annotations
// in the same positions
lazy val inferOverride = {
def unrollIndices(sym: Symbol): List[Int] =
sym.paramSymss.flatten.zipWithIndex.collect({
case (p, i) if p.hasAnnotation(defn.UnrollAnnot) => i
})

val candidate = defdef.symbol.nextOverriddenSymbol
candidate.exists && !candidate.is(Deferred) && candidate.hasAnnotation(defn.AbstractUnrollAnnot) && {
// check unroll indices match
unrollIndices(candidate) == unrollIndices(defdef.symbol)
}
}

val isCaseCopy =
defdef.name.toString == "copy" && defdef.symbol.owner.is(CaseClass)

Expand All @@ -240,24 +301,17 @@ class UnrollDefinitions extends MacroTransform, IdentityDenotTransformer {
else if (isCaseFromProduct) defdef.symbol.owner.companionClass.primaryConstructor
else defdef.symbol

annotated
.paramSymss
.zipWithIndex
.flatMap { (paramClause, paramClauseIndex) =>
val annotationIndices = findUnrollAnnotations(paramClause)
if (annotationIndices.isEmpty) None
else Some((paramClauseIndex, annotationIndices))
} match {
compute(annotated) match {
case Nil => (None, Nil)
case Seq((paramClauseIndex, annotationIndices)) =>
val paramCount = annotated.paramSymss(paramClauseIndex).size
if (isCaseFromProduct) {
val newDef = generateFromProduct(annotationIndices, paramCount, defdef)
(Some(defdef.symbol, Seq(newDef.symbol)), Seq(defdef.symbol -> newDef))
(Some(defdef.symbol), Seq(defdef.symbol -> newDef))
} else {
if (defdef.symbol.is(Deferred)){
val replacements = Seq.newBuilder[Symbol]
val newDefs =
(
Some(defdef.symbol),
(-1 +: annotationIndices :+ paramCount).sliding(2).toList.foldLeft((Seq.empty[(Symbol, DefDef)], defdef.symbol))((m, v) => ((m, v): @unchecked) match {
case ((defdefs, nextSymbol), Seq(paramIndex, nextParamIndex)) =>
val forwarder = generateSingleForwarder(
Expand All @@ -267,18 +321,15 @@ class UnrollDefinitions extends MacroTransform, IdentityDenotTransformer {
paramIndex,
nextSymbol,
paramClauseIndex,
isCaseApply
isCaseApply,
inferOverride
)
replacements += forwarder.symbol
// replacements += forwarder.symbol
((defdef.symbol -> forwarder) +: defdefs, forwarder.symbol)
})._1
(
Some(defdef.symbol, replacements.result()),
newDefs
)

}else{

(
None,
(annotationIndices :+ paramCount).sliding(2).toList.reverse.foldLeft((Seq.empty[(Symbol, DefDef)], defdef.symbol))((m, v) => ((m, v): @unchecked) match {
Expand All @@ -290,7 +341,8 @@ class UnrollDefinitions extends MacroTransform, IdentityDenotTransformer {
nextParamIndex,
nextSymbol,
paramClauseIndex,
isCaseApply
isCaseApply,
inferOverride
)
((defdef.symbol -> forwarder) +: defdefs, forwarder.symbol)
})._1
Expand All @@ -304,27 +356,19 @@ class UnrollDefinitions extends MacroTransform, IdentityDenotTransformer {
case _ => (None, Nil)
}

def unrollTemplate(tmpl: tpd.Template)(using Context): tpd.Tree = {
def unrollTemplate(tmpl: tpd.Template, compute: ComputeIndicies)(using Context): tpd.Tree = {

val (removed0, generatedDefs0) = tmpl.body.map(generateSyntheticDefs).unzip
val (removedCtor, generatedConstr0) = generateSyntheticDefs(tmpl.constr)
val removedFlat = removed0.flatten
val removedSymsBody = removedFlat.map(_(0))
val allRemoved = removedFlat ++ removedCtor
val (removed0, generatedDefs0) = tmpl.body.map(generateSyntheticDefs(_, compute)).unzip
val (removedCtor, generatedConstr0) = generateSyntheticDefs(tmpl.constr, compute)
val removedSymsBody = removed0.flatten
val allRemoved = removedSymsBody ++ removedCtor

val generatedDefOrigins = generatedDefs0.flatten
val generatedDefs = generatedDefOrigins.map(_(1))
val generatedConstr = generatedConstr0.map(_(1))

val otherDecls = tmpl.body.filter(t => !removedSymsBody.contains(t.symbol))

for (sym, replacements) <- allRemoved do
val cls = sym.owner.asClass
def totalParamCount(sym: Symbol): Int = sym.paramSymss.view.map(_.size).sum
val symParamCount = totalParamCount(sym)
val replaced = replacements.find(totalParamCount(_) == symParamCount).get
cls.replace(sym, replaced)

/** inlined from compiler/src/dotty/tools/dotc/typer/Checking.scala */
def checkClash(decl: Symbol, other: Symbol) =
def staticNonStaticPair = decl.isScalaStatic != other.isScalaStatic
Expand Down
10 changes: 10 additions & 0 deletions library/src/scala/annotation/internal/AbstractUnroll.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
package scala.annotation.internal

import scala.annotation.Annotation
import scala.annotation.experimental

/** Indicates the method was abstract in source code before unrolling transformation was added.
* downstream direct overrides, with matching `@unroll` annotations will infer an override.
*/
@experimental("under review as part of SIP-61")
final class AbstractUnroll extends Annotation
9 changes: 9 additions & 0 deletions library/src/scala/annotation/internal/UnrollForwarder.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
package scala.annotation.internal

import scala.annotation.Annotation
import scala.annotation.experimental

/** This method was generated via `@unroll` annotation.
*/
@experimental("under review as part of SIP-61")
final class UnrollForwarder extends Annotation
Loading

0 comments on commit 5b253a0

Please sign in to comment.