diff --git a/compiler/src/dotty/tools/dotc/cc/CaptureOps.scala b/compiler/src/dotty/tools/dotc/cc/CaptureOps.scala index 8100f78e50e7..5c9946f6134a 100644 --- a/compiler/src/dotty/tools/dotc/cc/CaptureOps.scala +++ b/compiler/src/dotty/tools/dotc/cc/CaptureOps.scala @@ -203,10 +203,6 @@ extension (tp: Type) case _ => false - def isCapabilityClassRef(using Context) = tp.dealiasKeepAnnots match - case _: TypeRef | _: AppliedType => tp.typeSymbol.hasAnnotation(defn.CapabilityAnnot) - case _ => false - /** Drop @retains annotations everywhere */ def dropAllRetains(using Context): Type = // TODO we should drop retains from inferred types before unpickling val tm = new TypeMap: diff --git a/compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala b/compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala index b60246af1f36..a5bb8792af2c 100644 --- a/compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala +++ b/compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala @@ -537,7 +537,8 @@ class CheckCaptures extends Recheck, SymTransformer: */ def addParamArgRefinements(core: Type, initCs: CaptureSet): (Type, CaptureSet) = var refined: Type = core - var allCaptures: CaptureSet = initCs + var allCaptures: CaptureSet = if setup.isCapabilityClassRef(core) + then CaptureSet.universal else initCs for (getterName, argType) <- mt.paramNames.lazyZip(argTypes) do val getter = cls.info.member(getterName).suchThat(_.is(ParamAccessor)).symbol if getter.termRef.isTracked && !getter.is(Private) then diff --git a/compiler/src/dotty/tools/dotc/cc/Setup.scala b/compiler/src/dotty/tools/dotc/cc/Setup.scala index 657e772b87ae..e6953dbf67b7 100644 --- a/compiler/src/dotty/tools/dotc/cc/Setup.scala +++ b/compiler/src/dotty/tools/dotc/cc/Setup.scala @@ -23,6 +23,7 @@ trait SetupAPI: def setupUnit(tree: Tree, recheckDef: DefRecheck)(using Context): Unit def isPreCC(sym: Symbol)(using Context): Boolean def postCheck()(using Context): Unit + def isCapabilityClassRef(tp: Type)(using Context): Boolean object Setup: @@ -67,6 +68,31 @@ class Setup extends PreRecheck, SymTransformer, SetupAPI: && !sym.owner.is(CaptureChecked) && !defn.isFunctionSymbol(sym.owner) + private val capabilityClassMap = new util.HashMap[Symbol, Boolean] + + /** Check if the class is capability, which means: + * 1. the class has a capability annotation, + * 2. or at least one of its parent type has universal capability. + */ + def isCapabilityClassRef(tp: Type)(using Context): Boolean = tp.dealiasKeepAnnots match + case _: TypeRef | _: AppliedType => + val sym = tp.classSymbol + def checkSym: Boolean = + sym.hasAnnotation(defn.CapabilityAnnot) + || sym.info.parents.exists(hasUniversalCapability) + sym.isClass && capabilityClassMap.getOrElseUpdate(sym, checkSym) + case _ => false + + private def hasUniversalCapability(tp: Type)(using Context): Boolean = tp.dealiasKeepAnnots match + case CapturingType(parent, refs) => + refs.isUniversal || hasUniversalCapability(parent) + case AnnotatedType(parent, ann) => + if ann.symbol.isRetains then + try ann.tree.toCaptureSet.isUniversal || hasUniversalCapability(parent) + catch case ex: IllegalCaptureRef => false + else hasUniversalCapability(parent) + case tp => isCapabilityClassRef(tp) + private def fluidify(using Context) = new TypeMap with IdempotentCaptRefMap: def apply(t: Type): Type = t match case t: MethodType => @@ -269,12 +295,6 @@ class Setup extends PreRecheck, SymTransformer, SetupAPI: CapturingType(fntpe, cs, boxed = false) else fntpe - /** Map references to capability classes C to C^ */ - private def expandCapabilityClass(tp: Type): Type = - if tp.isCapabilityClassRef - then CapturingType(tp, defn.expandedUniversalSet, boxed = false) - else tp - private def recur(t: Type): Type = normalizeCaptures(mapOver(t)) def apply(t: Type) = @@ -297,7 +317,8 @@ class Setup extends PreRecheck, SymTransformer, SetupAPI: case t: TypeVar => this(t.underlying) case t => - if t.isCapabilityClassRef + // Map references to capability classes C to C^ + if isCapabilityClassRef(t) then CapturingType(t, defn.expandedUniversalSet, boxed = false) else recur(t) end expandAliases diff --git a/compiler/src/dotty/tools/dotc/core/SymDenotations.scala b/compiler/src/dotty/tools/dotc/core/SymDenotations.scala index 4de93aa9c2a9..09d45dbdf06b 100644 --- a/compiler/src/dotty/tools/dotc/core/SymDenotations.scala +++ b/compiler/src/dotty/tools/dotc/core/SymDenotations.scala @@ -23,7 +23,7 @@ import scala.util.control.NonFatal import config.Config import reporting.* import collection.mutable -import cc.{CapturingType, derivedCapturingType} +import cc.{CapturingType, derivedCapturingType, stripCapturing} import scala.annotation.internal.sharable import scala.compiletime.uninitialized @@ -2225,7 +2225,7 @@ object SymDenotations { tp match { case tp @ TypeRef(prefix, _) => def foldGlb(bt: Type, ps: List[Type]): Type = ps match { - case p :: ps1 => foldGlb(bt & recur(p), ps1) + case p :: ps1 => foldGlb(bt & recur(p.stripCapturing), ps1) case _ => bt } diff --git a/tests/neg-custom-args/captures/extending-cap-classes.scala b/tests/neg-custom-args/captures/extending-cap-classes.scala new file mode 100644 index 000000000000..17497e415a1e --- /dev/null +++ b/tests/neg-custom-args/captures/extending-cap-classes.scala @@ -0,0 +1,15 @@ +import annotation.capability + +class C1 +@capability class C2 extends C1 +class C3 extends C2 + +def test = + val x1: C1 = new C1 + val x2: C1 = new C2 // error + val x3: C1 = new C3 // error + + val y1: C2 = new C2 + val y2: C2 = new C3 + + val z1: C3 = new C3 \ No newline at end of file diff --git a/tests/neg-custom-args/captures/extending-impure-function.scala b/tests/neg-custom-args/captures/extending-impure-function.scala new file mode 100644 index 000000000000..e491b31caed5 --- /dev/null +++ b/tests/neg-custom-args/captures/extending-impure-function.scala @@ -0,0 +1,30 @@ +class F1 extends (Int => Unit) { + def apply(x: Int): Unit = () +} + +class F2 extends (Int -> Unit) { + def apply(x: Int): Unit = () +} + +def test = + val x1 = new (Int => Unit) { + def apply(x: Int): Unit = () + } + + val x2: Int -> Unit = new (Int => Unit) { // error + def apply(x: Int): Unit = () + } + + val x3: Int -> Unit = new (Int -> Unit) { + def apply(x: Int): Unit = () + } + + val y1: Int => Unit = new F1 + val y2: Int -> Unit = new F1 // error + val y3: Int => Unit = new F2 + val y4: Int -> Unit = new F2 + + val z1 = () => () + val z2: () -> Unit = () => () + val z3: () -> Unit = z1 + val z4: () => Unit = () => ()