Skip to content

Commit

Permalink
Add support for some type aliases, when expanding context bounds for …
Browse files Browse the repository at this point in the history
…poly functions
  • Loading branch information
KacperFKorban committed Oct 3, 2024
1 parent 458fd29 commit 42d914e
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 27 deletions.
26 changes: 17 additions & 9 deletions compiler/src/dotty/tools/dotc/ast/Desugar.scala
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ object desugar {
/** An attachment key to indicate that a DefDef is a poly function apply
* method definition.
*/
val PolyFunctionApply: Property.Key[Unit] = Property.StickyKey()
val PolyFunctionApply: Property.Key[List[ValDef]] = Property.StickyKey()

/** What static check should be applied to a Match? */
enum MatchCheck {
Expand Down Expand Up @@ -514,17 +514,25 @@ object desugar {
case Nil =>
params :: Nil

// TODO(kπ) is this enough? SHould this be a TreeTraverse-thing?
def pushDownEvidenceParams(tree: Tree): Tree = tree match
case Function(params, body) =>
cpy.Function(tree)(params, pushDownEvidenceParams(body))
case Block(stats, expr) =>
cpy.Block(tree)(stats, pushDownEvidenceParams(expr))
case tree =>
val paramTpts = params.map(_.tpt)
val paramNames = params.map(_.name)
val paramsErased = params.map(_.mods.flags.is(Erased))
makeContextualFunction(paramTpts, paramNames, tree, paramsErased).withSpan(tree.span)

if meth.hasAttachment(PolyFunctionApply) then
meth.removeAttachment(PolyFunctionApply)
val paramTpts = params.map(_.tpt)
val paramNames = params.map(_.name)
val paramsErased = params.map(_.mods.flags.is(Erased))
// (kπ): deffer this until we can type the result?
if ctx.mode.is(Mode.Type) then
val ctxFunction = makeContextualFunction(paramTpts, paramNames, meth.tpt, paramsErased)
cpy.DefDef(meth)(tpt = ctxFunction)
cpy.DefDef(meth)(tpt = meth.tpt.withAttachment(PolyFunctionApply, params))
else
val ctxFunction = makeContextualFunction(paramTpts, paramNames, meth.rhs, paramsErased)
cpy.DefDef(meth)(rhs = ctxFunction)
cpy.DefDef(meth)(rhs = pushDownEvidenceParams(meth.rhs))
else
cpy.DefDef(meth)(paramss = recur(meth.paramss))
end addEvidenceParams
Expand Down Expand Up @@ -1251,7 +1259,7 @@ object desugar {
RefinedTypeTree(ref(defn.PolyFunctionType), List(
DefDef(nme.apply, tparams :: vparams :: Nil, res, EmptyTree)
.withFlags(Synthetic)
.withAttachment(PolyFunctionApply, ())
.withAttachment(PolyFunctionApply, List.empty)
)).withSpan(tree.span)
end makePolyFunctionType

Expand Down
56 changes: 41 additions & 15 deletions compiler/src/dotty/tools/dotc/typer/Typer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,6 @@ import config.MigrationVersion
import transform.CheckUnused.OriginalName

import scala.annotation.constructorOnly
import dotty.tools.dotc.ast.desugar.PolyFunctionApply

object Typer {

Expand Down Expand Up @@ -1958,7 +1957,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
untpd.InLambdaTypeTree(isResult = true, (tsyms, vsyms) =>
mt.resultType.substParams(mt, vsyms.map(_.termRef)).substParams(poly, tsyms.map(_.typeRef)))
val desugared @ Block(List(defdef), _) = desugar.makeClosure(tparams, inferredVParams, body, resultTpt, tree.span)
defdef.putAttachment(PolyFunctionApply, ())
defdef.putAttachment(desugar.PolyFunctionApply, List.empty)
typed(desugared, pt)
else
val msg =
Expand All @@ -1967,7 +1966,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
errorTree(EmptyTree, msg, tree.srcPos)
case _ =>
val desugared @ Block(List(defdef), _) = desugar.makeClosure(tparams, vparams, body, untpd.TypeTree(), tree.span)
defdef.putAttachment(PolyFunctionApply, ())
defdef.putAttachment(desugar.PolyFunctionApply, List.empty)
typed(desugared, pt)
end typedPolyFunctionValue

Expand Down Expand Up @@ -3580,30 +3579,57 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
case xtree => typedUnnamed(xtree)

val unsimplifiedType = result.tpe
simplify(result, pt, locked)
result.tpe.stripTypeVar match
val result1 = simplify(result, pt, locked)
result1.tpe.stripTypeVar match
case e: ErrorType if !unsimplifiedType.isErroneous => errorTree(xtree, e.msg, xtree.srcPos)
case _ => result
case _ => result1
catch case ex: TypeError =>
handleTypeError(ex)
}
}

private def pushDownDeferredEvidenceParams(tpe: Type, params: List[untpd.ValDef], span: Span)(using Context): Type = tpe.dealias match {
case tpe: MethodType =>
MethodType(tpe.paramNames)(paramNames => tpe.paramInfos, _ => pushDownDeferredEvidenceParams(tpe.resultType, params, span))
case tpe: PolyType =>
PolyType(tpe.paramNames)(paramNames => tpe.paramInfos, _ => pushDownDeferredEvidenceParams(tpe.resultType, params, span))
case tpe: RefinedType =>
// TODO(kπ): Doesn't seem right, but the PolyFunction ends up being a refinement
RefinedType(pushDownDeferredEvidenceParams(tpe.parent, params, span), tpe.refinedName, pushDownDeferredEvidenceParams(tpe.refinedInfo, params, span))
case tpe @ AppliedType(tycon, args) if defn.isFunctionType(tpe) && args.size > 1 =>
AppliedType(tpe.tycon, args.init :+ pushDownDeferredEvidenceParams(args.last, params, span))
case tpe =>
val paramNames = params.map(_.name)
val paramTpts = params.map(_.tpt)
val paramsErased = params.map(_.mods.flags.is(Erased))
val ctxFunction = desugar.makeContextualFunction(paramTpts, paramNames, untpd.TypedSplice(TypeTree(tpe.dealias)), paramsErased).withSpan(span)
typed(ctxFunction).tpe
}

private def addDownDeferredEvidenceParams(tree: Tree, pt: Type)(using Context): (Tree, Type) = {
tree.getAttachment(desugar.PolyFunctionApply) match
case Some(params) if params.nonEmpty =>
tree.removeAttachment(desugar.PolyFunctionApply)
val tpe = pushDownDeferredEvidenceParams(tree.tpe, params, tree.span)
TypeTree(tpe).withSpan(tree.span) -> tpe
case _ => tree -> pt
}

/** Interpolate and simplify the type of the given tree. */
protected def simplify(tree: Tree, pt: Type, locked: TypeVars)(using Context): tree.type =
if !tree.denot.isOverloaded then // for overloaded trees: resolve overloading before simplifying
if !tree.tpe.widen.isInstanceOf[MethodOrPoly] // wait with simplifying until method is fully applied
|| tree.isDef // ... unless tree is a definition
protected def simplify(tree: Tree, pt: Type, locked: TypeVars)(using Context): Tree =
val (tree1, pt1) = addDownDeferredEvidenceParams(tree, pt)
if !tree1.denot.isOverloaded then // for overloaded trees: resolve overloading before simplifying
if !tree1.tpe.widen.isInstanceOf[MethodOrPoly] // wait with simplifying until method is fully applied
|| tree1.isDef // ... unless tree is a definition
then
interpolateTypeVars(tree, pt, locked)
val simplified = tree.tpe.simplified
if !MatchType.thatReducesUsingGadt(tree.tpe) then // needs a GADT cast. i15743
interpolateTypeVars(tree1, pt1, locked)
val simplified = tree1.tpe.simplified
if !MatchType.thatReducesUsingGadt(tree1.tpe) then // needs a GADT cast. i15743
tree.overwriteType(simplified)
tree
tree1

protected def makeContextualFunction(tree: untpd.Tree, pt: Type)(using Context): Tree = {
val defn.FunctionOf(formals, _, true) = pt.dropDependentRefinement: @unchecked
println(i"make contextual function $tree / $pt")
val paramNamesOrNil = pt match
case RefinedType(_, _, rinfo: MethodType) => rinfo.paramNames
case _ => Nil
Expand Down
9 changes: 6 additions & 3 deletions tests/pos/contextbounds-for-poly-functions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,12 @@ val less3: Comparer = [X: Ord as ord] => (x: X, y: X) => ord.compare(x, y) < 0
// type Comparer2 = [X: Ord] => Cmp[X]
// val less4: Comparer2 = [X: Ord] => (x: X, y: X) => summon[Ord[X]].compare(x, y) < 0

// type CmpWeak[X] = (x: X, y: X) => Boolean
// type Comparer2Weak = [X: Ord] => (x: X) => CmpWeak[X]
// val less4: Comparer2Weak = [X: Ord] => (x: X) => (x: X, y: X) => summon[Ord[X]].compare(x, y) < 0
type CmpWeak[X] = X => Boolean
type Comparer2Weak = [X: Ord] => X => CmpWeak[X]
val less4_0: [X: Ord] => X => X => Boolean =
[X: Ord] => (x: X) => (y: X) => summon[Ord[X]].compare(x, y) < 0
val less4: Comparer2Weak =
[X: Ord] => (x: X) => (y: X) => summon[Ord[X]].compare(x, y) < 0

val less5 = [X: [X] =>> Ord[X]] => (x: X, y: X) => summon[Ord[X]].compare(x, y) < 0

Expand Down

0 comments on commit 42d914e

Please sign in to comment.