From 6078710a7de36e0e924692cd261ca7229eabd0c9 Mon Sep 17 00:00:00 2001 From: Nicolas Stucki Date: Tue, 27 Feb 2024 18:28:27 +0100 Subject: [PATCH] Specialized retained inline FunctionN apply methods Fixes #19724 [Cherry-picked 22020f37f96b64042c3b85371792220f4f957fcb] --- .../dotc/transform/SpecializeFunctions.scala | 25 ++++++++++++++++--- tests/pos/i19724.scala | 5 ++++ tests/run/i19724.scala | 18 +++++++++++++ 3 files changed, 44 insertions(+), 4 deletions(-) create mode 100644 tests/pos/i19724.scala create mode 100644 tests/run/i19724.scala diff --git a/compiler/src/dotty/tools/dotc/transform/SpecializeFunctions.scala b/compiler/src/dotty/tools/dotc/transform/SpecializeFunctions.scala index 094d6024eb4e..43aef6279cec 100644 --- a/compiler/src/dotty/tools/dotc/transform/SpecializeFunctions.scala +++ b/compiler/src/dotty/tools/dotc/transform/SpecializeFunctions.scala @@ -3,7 +3,7 @@ package transform import ast.Trees.*, ast.tpd, core.* import Contexts.*, Types.*, Decorators.*, Symbols.*, DenotTransformers.* -import SymDenotations.*, Scopes.*, StdNames.*, NameOps.*, Names.* +import SymDenotations.*, Scopes.*, StdNames.*, NameOps.*, Names.*, NameKinds.* import MegaPhase.MiniPhase @@ -25,7 +25,24 @@ class SpecializeFunctions extends MiniPhase { /** Create forwarders from the generic applys to the specialized ones. */ override def transformDefDef(ddef: DefDef)(using Context) = { - if ddef.name != nme.apply + // Note on special case for inline `apply`s: + // `apply` and `apply$retainedBody` are specialized in this transformation. + // `apply$retainedBody` have the name kind `BodyRetainerName`, these contain + // the runtime implementation of an inline `apply` that implements (or overrides) + // the `FunctionN.apply` method. The inline method is not specialized, it will + // be replaced with the implementation of `apply$retainedBody`. The following code + // inline def apply(x: Int): Double = x.toDouble:Double + // private def apply$retainedBody(x: Int): Double = x.toDouble:Double + // in is transformed into + // inline def apply(x: Int): Double = x.toDouble:Double + // private def apply$retainedBody(x: Int): Double = this.apply$mcDI$sp(x) + // def apply$mcDI$sp(v: Int): Double = x.toDouble:Double + // after erasure it will become + // def apply(v: Int): Double = this.apply$mcDI$sp(v) // from apply$retainedBody + // def apply$mcDI$sp(v: Int): Double = v.toDouble():Double + // def apply(v1: Object): Object = Double.box(this.apply(Int.unbox(v1))) // erasure bridge + + if ddef.name.asTermName.exclude(BodyRetainerName) != nme.apply || ddef.termParamss.length != 1 || ddef.termParamss.head.length > 2 || !ctx.owner.isClass @@ -44,12 +61,12 @@ class SpecializeFunctions extends MiniPhase { defn.isSpecializableFunction(cls, paramTypes, retType) } - if (sym.is(Flags.Deferred) || !isSpecializable) return ddef + if (sym.is(Flags.Deferred) || sym.is(Flags.Inline) || !isSpecializable) return ddef val specializedApply = newSymbol( cls, specName.nn, - sym.flags | Flags.Synthetic, + (sym.flags | Flags.Synthetic) &~ Flags.Private, // Private flag can be set if the name is a BodyRetainerName sym.info ).entered diff --git a/tests/pos/i19724.scala b/tests/pos/i19724.scala new file mode 100644 index 000000000000..776cf9167890 --- /dev/null +++ b/tests/pos/i19724.scala @@ -0,0 +1,5 @@ +object repro: + abstract class Mapper[A, B] extends (A => B) + + given Mapper[Int, Double] with + inline def apply(v: Int): Double = v.toDouble diff --git a/tests/run/i19724.scala b/tests/run/i19724.scala new file mode 100644 index 000000000000..0ed6fcb94c57 --- /dev/null +++ b/tests/run/i19724.scala @@ -0,0 +1,18 @@ +class F0 extends (() => Double): + inline def apply(): Double = 1.toDouble + +class F1 extends (Int => Double): + inline def apply(v: Int): Double = v.toDouble + +class F2 extends ((Int, Int) => Double): + inline def apply(v1: Int, v2: Int): Double = (v1 + v2).toDouble + +@main def Test = + val f0: (() => Double) = new F0 + assert(f0() == 1.0) + + val f1: (Int => Double) = new F1 + assert(f1(3) == 3.0) + + val f2: ((Int, Int) => Double) = new F2 + assert(f2(3, 2) == 5.0)