diff --git a/compiler/src/dotty/tools/dotc/core/Definitions.scala b/compiler/src/dotty/tools/dotc/core/Definitions.scala index 861decbda54d..04dc1a8b8a30 100644 --- a/compiler/src/dotty/tools/dotc/core/Definitions.scala +++ b/compiler/src/dotty/tools/dotc/core/Definitions.scala @@ -517,6 +517,7 @@ class Definitions { def ListType: TypeRef = ListClass.typeRef @tu lazy val ListModule: Symbol = requiredModule("scala.collection.immutable.List") @tu lazy val ListModule_apply: Symbol = ListModule.requiredMethod(nme.apply) + def ListModuleAlias: Symbol = ScalaPackageClass.requiredMethod(nme.List) @tu lazy val NilModule: Symbol = requiredModule("scala.collection.immutable.Nil") def NilType: TermRef = NilModule.termRef @tu lazy val ConsClass: Symbol = requiredClass("scala.collection.immutable.::") @@ -531,8 +532,11 @@ class Definitions { List(AnyType), EmptyScope) @tu lazy val SingletonType: TypeRef = SingletonClass.typeRef - @tu lazy val CollectionSeqType: TypeRef = requiredClassRef("scala.collection.Seq") - @tu lazy val SeqType: TypeRef = requiredClassRef("scala.collection.immutable.Seq") + @tu lazy val CollectionSeqType: TypeRef = requiredClassRef("scala.collection.Seq") + @tu lazy val SeqType: TypeRef = requiredClassRef("scala.collection.immutable.Seq") + @tu lazy val SeqModule: Symbol = requiredModule("scala.collection.immutable.Seq") + @tu lazy val SeqModule_apply: Symbol = SeqModule.requiredMethod(nme.apply) + def SeqModuleAlias: Symbol = ScalaPackageClass.requiredMethod(nme.Seq) def SeqClass(using Context): ClassSymbol = SeqType.symbol.asClass @tu lazy val Seq_apply : Symbol = SeqClass.requiredMethod(nme.apply) @tu lazy val Seq_head : Symbol = SeqClass.requiredMethod(nme.head) @@ -540,8 +544,6 @@ class Definitions { @tu lazy val Seq_lengthCompare: Symbol = SeqClass.requiredMethod(nme.lengthCompare, List(IntType)) @tu lazy val Seq_length : Symbol = SeqClass.requiredMethod(nme.length) @tu lazy val Seq_toSeq : Symbol = SeqClass.requiredMethod(nme.toSeq) - @tu lazy val SeqModule : Symbol = requiredModule("scala.collection.immutable.Seq") - @tu lazy val SeqModule_apply : Symbol = SeqModule.requiredMethod(nme.apply) @tu lazy val StringOps: Symbol = requiredClass("scala.collection.StringOps") diff --git a/compiler/src/dotty/tools/dotc/transform/ArrayApply.scala b/compiler/src/dotty/tools/dotc/transform/ArrayApply.scala index 652959b83227..8021b49a795f 100644 --- a/compiler/src/dotty/tools/dotc/transform/ArrayApply.scala +++ b/compiler/src/dotty/tools/dotc/transform/ArrayApply.scala @@ -1,15 +1,16 @@ -package dotty.tools.dotc +package dotty.tools +package dotc package transform import core._ import MegaPhase._ import Contexts._ +import Decorators.* import Symbols._ import Flags._ import StdNames._ -import dotty.tools.dotc.ast.tpd - - +import ast.tpd +import reporting.trace /** This phase rewrites calls to `Array.apply` to a direct instantiation of the array in the bytecode. * @@ -22,49 +23,66 @@ class ArrayApply extends MiniPhase { override def description: String = ArrayApply.description - private var transformListApplyLimit = 8 + private val transformListApplyLimit = 8 - private def reducingTransformListApply[A](depth: Int)(body: => A): A = { - val saved = transformListApplyLimit - transformListApplyLimit -= depth - try body - finally transformListApplyLimit = saved - } + override def transformTypeApply(tree: TypeApply)(using Context): Tree = + stripCast(tree) match + case app: Apply if isConsChain(app) => app + case _ => tree - override def transformApply(tree: tpd.Apply)(using Context): tpd.Tree = + override def transformApply(tree: Apply)(using Context): Tree = if isArrayModuleApply(tree.symbol) then tree.args match - case StripAscription(Apply(wrapRefArrayMeth, (seqLit: tpd.JavaSeqLiteral) :: Nil)) :: ct :: Nil + case StripAscription(Apply(wrapRefArrayMeth, (seqLit: JavaSeqLiteral) :: Nil)) :: ct :: Nil if defn.WrapArrayMethods().contains(wrapRefArrayMeth.symbol) && elideClassTag(ct) => seqLit - case elem0 :: StripAscription(Apply(wrapRefArrayMeth, (seqLit: tpd.JavaSeqLiteral) :: Nil)) :: Nil + case elem0 :: StripAscription(Apply(wrapRefArrayMeth, (seqLit: JavaSeqLiteral) :: Nil)) :: Nil if defn.WrapArrayMethods().contains(wrapRefArrayMeth.symbol) => - tpd.JavaSeqLiteral(elem0 :: seqLit.elems, seqLit.elemtpt) + JavaSeqLiteral(elem0 :: seqLit.elems, seqLit.elemtpt) case _ => tree - else if isListOrSeqModuleApply(tree.symbol) then + else if isSeqApply(tree) then tree.args match // (a, b, c) ~> new ::(a, new ::(b, new ::(c, Nil))) but only for reference types - case StripAscription(Apply(wrapArrayMeth, List(StripAscription(rest: tpd.JavaSeqLiteral)))) :: Nil + case StripAscription(Apply(wrapArrayMeth, List(StripAscription(rest: JavaSeqLiteral)))) :: Nil if defn.WrapArrayMethods().contains(wrapArrayMeth.symbol) && rest.elems.lengthIs < transformListApplyLimit => - rest.elems.foldRight(tpd.ref(defn.NilModule)): (elem, acc) => - tpd.New(defn.ConsType, List(elem.ensureConforms(defn.ObjectType), acc)) + rest.elems.foldRight(ref(defn.NilModule)): (elem, acc) => + New(defn.ConsType, List(elem.ensureConforms(defn.ObjectType), acc)) case _ => tree else tree + private def isConsChain(tree: Tree)(using Context): Boolean = tree match + case Apply(Select(New(tt), nme.CONSTRUCTOR), List(_, arg)) => + tt.symbol == defn.ConsClass && (arg.symbol == defn.NilModule || isConsChain(arg)) + case _ => false + private def isArrayModuleApply(sym: Symbol)(using Context): Boolean = sym.name == nme.apply && (sym.owner == defn.ArrayModuleClass || (sym.owner == defn.IArrayModuleClass && !sym.is(Extension))) - private def isListOrSeqModuleApply(sym: Symbol)(using Context): Boolean = - sym == defn.ListModule_apply || sym == defn.SeqModule_apply + private def isListApply(tree: Tree)(using Context): Boolean = + (tree.symbol == defn.ListModule_apply || tree.symbol.name == nme.apply) && appliedCore(tree).match + case Select(qual, _) => + val sym = qual.symbol + sym == defn.ListModule + || sym == defn.ListModuleAlias + case _ => false + + private def isSeqApply(tree: Tree)(using Context): Boolean = + isListApply(tree) || tree.symbol == defn.SeqModule_apply && appliedCore(tree).match + case Select(qual, _) => + val sym = qual.symbol + sym == defn.SeqModule + || sym == defn.SeqModuleAlias + || sym == defn.CollectionSeqType.symbol.companionModule + case _ => false /** Only optimize when classtag if it is one of * - `ClassTag.apply(classOf[XYZ])` diff --git a/compiler/test/dotty/tools/backend/jvm/ArrayApplyOptTest.scala b/compiler/test/dotty/tools/backend/jvm/ArrayApplyOptTest.scala index a2d37b8399e5..264571d6b6e6 100644 --- a/compiler/test/dotty/tools/backend/jvm/ArrayApplyOptTest.scala +++ b/compiler/test/dotty/tools/backend/jvm/ArrayApplyOptTest.scala @@ -1,4 +1,5 @@ -package dotty.tools.backend.jvm +package dotty.tools +package backend.jvm import org.junit.Test import org.junit.Assert._ @@ -161,15 +162,50 @@ class ArrayApplyOptTest extends DottyBytecodeTest { } @Test def testListApplyAvoidsIntermediateArray = { - val source = + checkApplyAvoidsIntermediateArray("List"): """ |class Foo { | def meth1: List[String] = List("1", "2", "3") | def meth2: List[String] = - | new scala.collection.immutable.::("1", new scala.collection.immutable.::("2", new scala.collection.immutable.::("3", scala.collection.immutable.Nil))).asInstanceOf[List[String]] + | new scala.collection.immutable.::("1", new scala.collection.immutable.::("2", new scala.collection.immutable.::("3", scala.collection.immutable.Nil))).asInstanceOf[List[String]] + |} + """.stripMargin + } + + @Test def testSeqApplyAvoidsIntermediateArray = { + checkApplyAvoidsIntermediateArray("Seq"): + """ + |class Foo { + | def meth1: Seq[String] = Seq("1", "2", "3") + | def meth2: Seq[String] = + | new scala.collection.immutable.::("1", new scala.collection.immutable.::("2", new scala.collection.immutable.::("3", scala.collection.immutable.Nil))).asInstanceOf[List[String]] |} """.stripMargin + } + + @Test def testSeqApplyAvoidsIntermediateArray2 = { + checkApplyAvoidsIntermediateArray("scala.collection.immutable.Seq"): + """import scala.collection.immutable.Seq + |class Foo { + | def meth1: Seq[String] = Seq("1", "2", "3") + | def meth2: Seq[String] = + | new scala.collection.immutable.::("1", new scala.collection.immutable.::("2", new scala.collection.immutable.::("3", scala.collection.immutable.Nil))).asInstanceOf[List[String]] + |} + """.stripMargin + } + + @Test def testSeqApplyAvoidsIntermediateArray3 = { + checkApplyAvoidsIntermediateArray("scala.collection.Seq"): + """import scala.collection.Seq + |class Foo { + | def meth1: Seq[String] = Seq("1", "2", "3") + | def meth2: Seq[String] = + | new scala.collection.immutable.::("1", new scala.collection.immutable.::("2", new scala.collection.immutable.::("3", scala.collection.immutable.Nil))).asInstanceOf[List[String]] + |} + """.stripMargin + } + def checkApplyAvoidsIntermediateArray(name: String)(source: String) = { checkBCode(source) { dir => val clsIn = dir.lookupName("Foo.class", directory = false).input val clsNode = loadClassNode(clsIn) @@ -180,7 +216,7 @@ class ArrayApplyOptTest extends DottyBytecodeTest { val instructions2 = instructionsFromMethod(meth2) assert(instructions1 == instructions2, - "the List.apply method " + + s"the $name.apply method\n" + diffInstructions(instructions1, instructions2)) } } diff --git a/tests/run/list-apply-eval.scala b/tests/run/list-apply-eval.scala index 4e25444689cc..bd60cd50c8db 100644 --- a/tests/run/list-apply-eval.scala +++ b/tests/run/list-apply-eval.scala @@ -19,3 +19,6 @@ object Test: val emptyList = List[Int]() assert(emptyList == Nil) + + // just assert it doesn't throw CCE to List + val queue = scala.collection.mutable.Queue[String]()