Skip to content

Commit

Permalink
Fixup List optimisation improvement
Browse files Browse the repository at this point in the history
  • Loading branch information
dwijnand committed Nov 9, 2023
1 parent 3ed582c commit f940d92
Show file tree
Hide file tree
Showing 4 changed files with 88 additions and 29 deletions.
10 changes: 6 additions & 4 deletions compiler/src/dotty/tools/dotc/core/Definitions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -517,6 +517,7 @@ class Definitions {
def ListType: TypeRef = ListClass.typeRef
@tu lazy val ListModule: Symbol = requiredModule("scala.collection.immutable.List")
@tu lazy val ListModule_apply: Symbol = ListModule.requiredMethod(nme.apply)
def ListModuleAlias: Symbol = ScalaPackageClass.requiredMethod(nme.List)
@tu lazy val NilModule: Symbol = requiredModule("scala.collection.immutable.Nil")
def NilType: TermRef = NilModule.termRef
@tu lazy val ConsClass: Symbol = requiredClass("scala.collection.immutable.::")
Expand All @@ -531,17 +532,18 @@ class Definitions {
List(AnyType), EmptyScope)
@tu lazy val SingletonType: TypeRef = SingletonClass.typeRef

@tu lazy val CollectionSeqType: TypeRef = requiredClassRef("scala.collection.Seq")
@tu lazy val SeqType: TypeRef = requiredClassRef("scala.collection.immutable.Seq")
@tu lazy val CollectionSeqType: TypeRef = requiredClassRef("scala.collection.Seq")
@tu lazy val SeqType: TypeRef = requiredClassRef("scala.collection.immutable.Seq")
@tu lazy val SeqModule: Symbol = requiredModule("scala.collection.immutable.Seq")
@tu lazy val SeqModule_apply: Symbol = SeqModule.requiredMethod(nme.apply)
def SeqModuleAlias: Symbol = ScalaPackageClass.requiredMethod(nme.Seq)
def SeqClass(using Context): ClassSymbol = SeqType.symbol.asClass
@tu lazy val Seq_apply : Symbol = SeqClass.requiredMethod(nme.apply)
@tu lazy val Seq_head : Symbol = SeqClass.requiredMethod(nme.head)
@tu lazy val Seq_drop : Symbol = SeqClass.requiredMethod(nme.drop)
@tu lazy val Seq_lengthCompare: Symbol = SeqClass.requiredMethod(nme.lengthCompare, List(IntType))
@tu lazy val Seq_length : Symbol = SeqClass.requiredMethod(nme.length)
@tu lazy val Seq_toSeq : Symbol = SeqClass.requiredMethod(nme.toSeq)
@tu lazy val SeqModule : Symbol = requiredModule("scala.collection.immutable.Seq")
@tu lazy val SeqModule_apply : Symbol = SeqModule.requiredMethod(nme.apply)


@tu lazy val StringOps: Symbol = requiredClass("scala.collection.StringOps")
Expand Down
60 changes: 39 additions & 21 deletions compiler/src/dotty/tools/dotc/transform/ArrayApply.scala
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
package dotty.tools.dotc
package dotty.tools
package dotc
package transform

import core._
import MegaPhase._
import Contexts._
import Decorators.*
import Symbols._
import Flags._
import StdNames._
import dotty.tools.dotc.ast.tpd


import ast.tpd
import reporting.trace

/** This phase rewrites calls to `Array.apply` to a direct instantiation of the array in the bytecode.
*
Expand All @@ -22,49 +23,66 @@ class ArrayApply extends MiniPhase {

override def description: String = ArrayApply.description

private var transformListApplyLimit = 8
private val transformListApplyLimit = 8

private def reducingTransformListApply[A](depth: Int)(body: => A): A = {
val saved = transformListApplyLimit
transformListApplyLimit -= depth
try body
finally transformListApplyLimit = saved
}
override def transformTypeApply(tree: TypeApply)(using Context): Tree =
stripCast(tree) match
case app: Apply if isConsChain(app) => app
case _ => tree

override def transformApply(tree: tpd.Apply)(using Context): tpd.Tree =
override def transformApply(tree: Apply)(using Context): Tree =
if isArrayModuleApply(tree.symbol) then
tree.args match
case StripAscription(Apply(wrapRefArrayMeth, (seqLit: tpd.JavaSeqLiteral) :: Nil)) :: ct :: Nil
case StripAscription(Apply(wrapRefArrayMeth, (seqLit: JavaSeqLiteral) :: Nil)) :: ct :: Nil
if defn.WrapArrayMethods().contains(wrapRefArrayMeth.symbol) && elideClassTag(ct) =>
seqLit

case elem0 :: StripAscription(Apply(wrapRefArrayMeth, (seqLit: tpd.JavaSeqLiteral) :: Nil)) :: Nil
case elem0 :: StripAscription(Apply(wrapRefArrayMeth, (seqLit: JavaSeqLiteral) :: Nil)) :: Nil
if defn.WrapArrayMethods().contains(wrapRefArrayMeth.symbol) =>
tpd.JavaSeqLiteral(elem0 :: seqLit.elems, seqLit.elemtpt)
JavaSeqLiteral(elem0 :: seqLit.elems, seqLit.elemtpt)

case _ =>
tree

else if isListOrSeqModuleApply(tree.symbol) then
else if isSeqApply(tree) then
tree.args match
// <List or Seq>(a, b, c) ~> new ::(a, new ::(b, new ::(c, Nil))) but only for reference types
case StripAscription(Apply(wrapArrayMeth, List(StripAscription(rest: tpd.JavaSeqLiteral)))) :: Nil
case StripAscription(Apply(wrapArrayMeth, List(StripAscription(rest: JavaSeqLiteral)))) :: Nil
if defn.WrapArrayMethods().contains(wrapArrayMeth.symbol) &&
rest.elems.lengthIs < transformListApplyLimit =>
rest.elems.foldRight(tpd.ref(defn.NilModule)): (elem, acc) =>
tpd.New(defn.ConsType, List(elem.ensureConforms(defn.ObjectType), acc))
rest.elems.foldRight(ref(defn.NilModule)): (elem, acc) =>
New(defn.ConsType, List(elem.ensureConforms(defn.ObjectType), acc))

case _ =>
tree

else tree

private def isConsChain(tree: Tree)(using Context): Boolean = tree match
case Apply(Select(New(tt), nme.CONSTRUCTOR), List(_, arg)) =>
tt.symbol == defn.ConsClass && (arg.symbol == defn.NilModule || isConsChain(arg))
case _ => false

private def isArrayModuleApply(sym: Symbol)(using Context): Boolean =
sym.name == nme.apply
&& (sym.owner == defn.ArrayModuleClass || (sym.owner == defn.IArrayModuleClass && !sym.is(Extension)))

private def isListOrSeqModuleApply(sym: Symbol)(using Context): Boolean =
sym == defn.ListModule_apply || sym == defn.SeqModule_apply
private def isListApply(tree: Tree)(using Context): Boolean =
(tree.symbol == defn.ListModule_apply || tree.symbol.name == nme.apply) && appliedCore(tree).match
case Select(qual, _) =>
val sym = qual.symbol
sym == defn.ListModule
|| sym == defn.ListModuleAlias
case _ => false

private def isSeqApply(tree: Tree)(using Context): Boolean =
isListApply(tree) || tree.symbol == defn.SeqModule_apply && appliedCore(tree).match
case Select(qual, _) =>
val sym = qual.symbol
sym == defn.SeqModule
|| sym == defn.SeqModuleAlias
|| sym == defn.CollectionSeqType.symbol.companionModule
case _ => false

/** Only optimize when classtag if it is one of
* - `ClassTag.apply(classOf[XYZ])`
Expand Down
44 changes: 40 additions & 4 deletions compiler/test/dotty/tools/backend/jvm/ArrayApplyOptTest.scala
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
package dotty.tools.backend.jvm
package dotty.tools
package backend.jvm

import org.junit.Test
import org.junit.Assert._
Expand Down Expand Up @@ -161,15 +162,50 @@ class ArrayApplyOptTest extends DottyBytecodeTest {
}

@Test def testListApplyAvoidsIntermediateArray = {
val source =
checkApplyAvoidsIntermediateArray("List"):
"""
|class Foo {
| def meth1: List[String] = List("1", "2", "3")
| def meth2: List[String] =
| new scala.collection.immutable.::("1", new scala.collection.immutable.::("2", new scala.collection.immutable.::("3", scala.collection.immutable.Nil))).asInstanceOf[List[String]]
| new scala.collection.immutable.::("1", new scala.collection.immutable.::("2", new scala.collection.immutable.::("3", scala.collection.immutable.Nil))).asInstanceOf[List[String]]
|}
""".stripMargin
}

@Test def testSeqApplyAvoidsIntermediateArray = {
checkApplyAvoidsIntermediateArray("Seq"):
"""
|class Foo {
| def meth1: Seq[String] = Seq("1", "2", "3")
| def meth2: Seq[String] =
| new scala.collection.immutable.::("1", new scala.collection.immutable.::("2", new scala.collection.immutable.::("3", scala.collection.immutable.Nil))).asInstanceOf[List[String]]
|}
""".stripMargin
}

@Test def testSeqApplyAvoidsIntermediateArray2 = {
checkApplyAvoidsIntermediateArray("scala.collection.immutable.Seq"):
"""import scala.collection.immutable.Seq
|class Foo {
| def meth1: Seq[String] = Seq("1", "2", "3")
| def meth2: Seq[String] =
| new scala.collection.immutable.::("1", new scala.collection.immutable.::("2", new scala.collection.immutable.::("3", scala.collection.immutable.Nil))).asInstanceOf[List[String]]
|}
""".stripMargin
}

@Test def testSeqApplyAvoidsIntermediateArray3 = {
checkApplyAvoidsIntermediateArray("scala.collection.Seq"):
"""import scala.collection.Seq
|class Foo {
| def meth1: Seq[String] = Seq("1", "2", "3")
| def meth2: Seq[String] =
| new scala.collection.immutable.::("1", new scala.collection.immutable.::("2", new scala.collection.immutable.::("3", scala.collection.immutable.Nil))).asInstanceOf[List[String]]
|}
""".stripMargin
}

def checkApplyAvoidsIntermediateArray(name: String)(source: String) = {
checkBCode(source) { dir =>
val clsIn = dir.lookupName("Foo.class", directory = false).input
val clsNode = loadClassNode(clsIn)
Expand All @@ -180,7 +216,7 @@ class ArrayApplyOptTest extends DottyBytecodeTest {
val instructions2 = instructionsFromMethod(meth2)

assert(instructions1 == instructions2,
"the List.apply method " +
s"the $name.apply method\n" +
diffInstructions(instructions1, instructions2))
}
}
Expand Down
3 changes: 3 additions & 0 deletions tests/run/list-apply-eval.scala
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,6 @@ object Test:

val emptyList = List[Int]()
assert(emptyList == Nil)

// just assert it doesn't throw CCE to List
val queue = scala.collection.mutable.Queue[String]()

0 comments on commit f940d92

Please sign in to comment.