Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

List(...) optimization to avoid intermediate array #17166

Merged
merged 7 commits into from
Nov 28, 2023
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 15 additions & 11 deletions compiler/src/dotty/tools/dotc/core/Definitions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -517,14 +517,16 @@ class Definitions {
methodNames.map(getWrapVarargsArrayModule.requiredMethod(_))
})

@tu lazy val ListClass: Symbol = requiredClass("scala.collection.immutable.List")
def ListType: TypeRef = ListClass.typeRef
@tu lazy val ListModule: Symbol = requiredModule("scala.collection.immutable.List")
@tu lazy val NilModule: Symbol = requiredModule("scala.collection.immutable.Nil")
def NilType: TermRef = NilModule.termRef
@tu lazy val ConsClass: Symbol = requiredClass("scala.collection.immutable.::")
def ConsType: TypeRef = ConsClass.typeRef
@tu lazy val SeqFactoryClass: Symbol = requiredClass("scala.collection.SeqFactory")
@tu lazy val ListClass: Symbol = requiredClass("scala.collection.immutable.List")
def ListType: TypeRef = ListClass.typeRef
@tu lazy val ListModule: Symbol = requiredModule("scala.collection.immutable.List")
@tu lazy val ListModule_apply: Symbol = ListModule.requiredMethod(nme.apply)
def ListModuleAlias: Symbol = ScalaPackageClass.requiredMethod(nme.List)
@tu lazy val NilModule: Symbol = requiredModule("scala.collection.immutable.Nil")
def NilType: TermRef = NilModule.termRef
@tu lazy val ConsClass: Symbol = requiredClass("scala.collection.immutable.::")
def ConsType: TypeRef = ConsClass.typeRef
@tu lazy val SeqFactoryClass: Symbol = requiredClass("scala.collection.SeqFactory")

@tu lazy val SingletonClass: ClassSymbol =
// needed as a synthetic class because Scala 2.x refers to it in classfiles
Expand All @@ -534,16 +536,18 @@ class Definitions {
List(AnyType), EmptyScope)
@tu lazy val SingletonType: TypeRef = SingletonClass.typeRef

@tu lazy val CollectionSeqType: TypeRef = requiredClassRef("scala.collection.Seq")
@tu lazy val SeqType: TypeRef = requiredClassRef("scala.collection.immutable.Seq")
@tu lazy val CollectionSeqType: TypeRef = requiredClassRef("scala.collection.Seq")
@tu lazy val SeqType: TypeRef = requiredClassRef("scala.collection.immutable.Seq")
@tu lazy val SeqModule: Symbol = requiredModule("scala.collection.immutable.Seq")
@tu lazy val SeqModule_apply: Symbol = SeqModule.requiredMethod(nme.apply)
def SeqModuleAlias: Symbol = ScalaPackageClass.requiredMethod(nme.Seq)
def SeqClass(using Context): ClassSymbol = SeqType.symbol.asClass
@tu lazy val Seq_apply : Symbol = SeqClass.requiredMethod(nme.apply)
@tu lazy val Seq_head : Symbol = SeqClass.requiredMethod(nme.head)
@tu lazy val Seq_drop : Symbol = SeqClass.requiredMethod(nme.drop)
@tu lazy val Seq_lengthCompare: Symbol = SeqClass.requiredMethod(nme.lengthCompare, List(IntType))
@tu lazy val Seq_length : Symbol = SeqClass.requiredMethod(nme.length)
@tu lazy val Seq_toSeq : Symbol = SeqClass.requiredMethod(nme.toSeq)
@tu lazy val SeqModule: Symbol = requiredModule("scala.collection.immutable.Seq")


@tu lazy val StringOps: Symbol = requiredClass("scala.collection.StringOps")
Expand Down
71 changes: 56 additions & 15 deletions compiler/src/dotty/tools/dotc/transform/ArrayApply.scala
Original file line number Diff line number Diff line change
@@ -1,15 +1,12 @@
package dotty.tools.dotc
package dotty.tools
package dotc
package transform

import core.*
import ast.tpd
import core.*, Contexts.*, Decorators.*, Symbols.*, Flags.*, StdNames.*
import reporting.trace
import util.Property
import MegaPhase.*
import Contexts.*
import Symbols.*
import Flags.*
import StdNames.*
import dotty.tools.dotc.ast.tpd



/** This phase rewrites calls to `Array.apply` to a direct instantiation of the array in the bytecode.
*
Expand All @@ -22,27 +19,71 @@ class ArrayApply extends MiniPhase {

override def description: String = ArrayApply.description

override def transformApply(tree: tpd.Apply)(using Context): tpd.Tree =
private val TransformListApplyBudgetKey = new Property.Key[Int]
private def transformListApplyBudget(using Context) = ctx.property(TransformListApplyBudgetKey).getOrElse(8)

dwijnand marked this conversation as resolved.
Show resolved Hide resolved
override def prepareForApply(tree: Apply)(using Context): Context =
if isSeqApply(tree) then
val args = seqApplyArgsOrNull(tree)
if args != null then
ctx.fresh.setProperty(TransformListApplyBudgetKey, transformListApplyBudget - args.elems.length)
else ctx
else ctx

override def transformApply(tree: Apply)(using Context): Tree =
if isArrayModuleApply(tree.symbol) then
tree.args match {
case StripAscription(Apply(wrapRefArrayMeth, (seqLit: tpd.JavaSeqLiteral) :: Nil)) :: ct :: Nil
tree.args match
case StripAscription(Apply(wrapRefArrayMeth, (seqLit: JavaSeqLiteral) :: Nil)) :: ct :: Nil
if defn.WrapArrayMethods().contains(wrapRefArrayMeth.symbol) && elideClassTag(ct) =>
seqLit

case elem0 :: StripAscription(Apply(wrapRefArrayMeth, (seqLit: tpd.JavaSeqLiteral) :: Nil)) :: Nil
case elem0 :: StripAscription(Apply(wrapRefArrayMeth, (seqLit: JavaSeqLiteral) :: Nil)) :: Nil
if defn.WrapArrayMethods().contains(wrapRefArrayMeth.symbol) =>
tpd.JavaSeqLiteral(elem0 :: seqLit.elems, seqLit.elemtpt)
JavaSeqLiteral(elem0 :: seqLit.elems, seqLit.elemtpt)

case _ =>
tree
}

else if isSeqApply(tree) then
val args = seqApplyArgsOrNull(tree)
if args != null && (transformListApplyBudget > 0 || args.elems.isEmpty) then
val consed = args.elems.foldRight(ref(defn.NilModule)): (elem, acc) =>
New(defn.ConsType, List(elem.ensureConforms(defn.ObjectType), acc))
consed.cast(tree.tpe)
else tree

else tree

private def isArrayModuleApply(sym: Symbol)(using Context): Boolean =
sym.name == nme.apply
&& (sym.owner == defn.ArrayModuleClass || (sym.owner == defn.IArrayModuleClass && !sym.is(Extension)))

private def isListApply(tree: Tree)(using Context): Boolean =
(tree.symbol == defn.ListModule_apply || tree.symbol.name == nme.apply) && appliedCore(tree).match
case Select(qual, _) =>
val sym = qual.symbol
sym == defn.ListModule
|| sym == defn.ListModuleAlias
case _ => false

private def isSeqApply(tree: Tree)(using Context): Boolean =
isListApply(tree) || tree.symbol == defn.SeqModule_apply && appliedCore(tree).match
case Select(qual, _) =>
val sym = qual.symbol
sym == defn.SeqModule
|| sym == defn.SeqModuleAlias
|| sym == defn.CollectionSeqType.symbol.companionModule
case _ => false

private def seqApplyArgsOrNull(tree: Apply)(using Context): JavaSeqLiteral | Null =
// assumes isSeqApply(tree)
dwijnand marked this conversation as resolved.
Show resolved Hide resolved
tree.args match
// <List or Seq>(a, b, c) ~> new ::(a, new ::(b, new ::(c, Nil))) but only for reference types
case StripAscription(Apply(wrapArrayMeth, List(StripAscription(rest: JavaSeqLiteral)))) :: Nil
if defn.WrapArrayMethods().contains(wrapArrayMeth.symbol) =>
rest
case _ => null

/** Only optimize when classtag if it is one of
* - `ClassTag.apply(classOf[XYZ])`
* - `ClassTag.apply(java.lang.XYZ.Type)` for boxed primitives `XYZ``
Expand Down
78 changes: 77 additions & 1 deletion compiler/test/dotty/tools/backend/jvm/ArrayApplyOptTest.scala
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
package dotty.tools.backend.jvm
package dotty.tools
package backend.jvm

import org.junit.Test
import org.junit.Assert._
Expand Down Expand Up @@ -160,4 +161,79 @@ class ArrayApplyOptTest extends DottyBytecodeTest {
}
}

@Test def testListApplyAvoidsIntermediateArray = {
checkApplyAvoidsIntermediateArray("List"):
"""import scala.collection.immutable.{ ::, Nil }
|class Foo {
| def meth1: List[String] = List("1", "2", "3")
| def meth2: List[String] = new ::("1", new ::("2", new ::("3", Nil)))
|}
""".stripMargin
}

@Test def testSeqApplyAvoidsIntermediateArray = {
checkApplyAvoidsIntermediateArray("Seq"):
"""import scala.collection.immutable.{ ::, Nil }
|class Foo {
| def meth1: Seq[String] = Seq("1", "2", "3")
| def meth2: Seq[String] = new ::("1", new ::("2", new ::("3", Nil)))
|}
""".stripMargin
}

@Test def testSeqApplyAvoidsIntermediateArray2 = {
checkApplyAvoidsIntermediateArray("scala.collection.immutable.Seq"):
"""import scala.collection.immutable.{ ::, Seq, Nil }
|class Foo {
| def meth1: Seq[String] = Seq("1", "2", "3")
| def meth2: Seq[String] = new ::("1", new ::("2", new ::("3", Nil)))
|}
""".stripMargin
}

@Test def testSeqApplyAvoidsIntermediateArray3 = {
checkApplyAvoidsIntermediateArray("scala.collection.Seq"):
"""import scala.collection.immutable.{ ::, Nil }, scala.collection.Seq
|class Foo {
| def meth1: Seq[String] = Seq("1", "2", "3")
| def meth2: Seq[String] = new ::("1", new ::("2", new ::("3", Nil)))
|}
""".stripMargin
}

def checkApplyAvoidsIntermediateArray(name: String)(source: String) = {
dwijnand marked this conversation as resolved.
Show resolved Hide resolved
checkBCode(source) { dir =>
val clsIn = dir.lookupName("Foo.class", directory = false).input
val clsNode = loadClassNode(clsIn)
val meth1 = getMethod(clsNode, "meth1")
val meth2 = getMethod(clsNode, "meth2")

val instructions1 = instructionsFromMethod(meth1) match
case instr :+ TypeOp(CHECKCAST, _) :+ TypeOp(CHECKCAST, _) :+ (ret @ Op(ARETURN)) =>
instr :+ ret
case instr :+ TypeOp(CHECKCAST, _) :+ (ret @ Op(ARETURN)) =>
// List.apply[?A] doesn't, strictly, return List[?A],
// because it cascades to its definition on IterableFactory
// where it returns CC[A]. The erasure of that is Object,
// which is why Erasure's Typer adds a cast to compensate.
// If we drop that cast while optimising (because using
// the constructor for :: doesn't require the cast like
// List.apply did) then then cons construction chain will
// be typed as ::.
// Unfortunately the LUB of :: and Nil.type is Product
// instead of List, so a cast remains necessary,
// across whatever causes the lub, like `if` or `try` branches.
// Therefore if we dropping the cast may cause a needed cast
// to be necessary, we shouldn't drop the cast,
// which was only motivated by the assert here.
instr :+ ret
case instr => instr
val instructions2 = instructionsFromMethod(meth2)

assert(instructions1 == instructions2,
s"the $name.apply method\n" +
diffInstructions(instructions1, instructions2))
}
}

}
89 changes: 89 additions & 0 deletions tests/run/list-apply-eval.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
object Test:

var counter = 0

def next =
counter += 1
counter.toString

def main(args: Array[String]): Unit =
//List.apply is subject to an optimisation in cleanup
//ensure that the arguments are evaluated in the currect order
// Rewritten to:
// val myList: List = new collection.immutable.::(Test.this.next(), new collection.immutable.::(Test.this.next(), new collection.immutable.::(Test.this.next(), scala.collection.immutable.Nil)));
val myList = List(next, next, next)
assert(myList == List("1", "2", "3"), myList)

val mySeq = Seq(next, next, next)
assert(mySeq == Seq("4", "5", "6"), mySeq)

val emptyList = List[Int]()
assert(emptyList == Nil)

// just assert it doesn't throw CCE to List
val queue = scala.collection.mutable.Queue[String]()

// test for the cast instruction described in checkApplyAvoidsIntermediateArray
def lub(b: Boolean): List[(String, String)] =
if b then List(("foo", "bar")) else Nil

// from minimising CI failure in oslib
// again, the lub of :: and Nil is Product, which breaks ++ (which requires IterableOnce)
def lub2(b: Boolean): Unit =
Seq(1) ++ (if (b) Seq(2) else Nil)

// Examples of arity and nesting arity
// to find the thresholds and reproduce the behaviour of nsc
// tested manually, comparing -Xprint across compilers (ran out of time)
def examples(): Unit =
val max1 = List[Object]("1", "2", "3", "4", "5", "6", "7") // 7 cons w/ 7 string heads + nil
val max2 = List[Object]("1", "2", "3", "4", "5", "6", List[Object]()) // 7 cons w/ 6 string heads + 1 nil head + nil
val max3 = List[Object]("1", "2", "3", "4", "5", List[Object]("6"))
val max4 = List[Object]("1", "2", "3", "4", List[Object]("5", "6"))

val over1 = List[Object]("1", "2", "3", "4", "5", "6", "7", "8") // wrap 8-sized array
val over2 = List[Object]("1", "2", "3", "4", "5", "6", "7", List[Object]()) // wrap 8-sized array
val over3 = List[Object]("1", "2", "3", "4", "5", "6", List[Object]("7")) // wrap 1-sized array with 7
val over4 = List[Object]("1", "2", "3", "4", "5", List[Object]("6", "7")) // wrap 2

val max5 =
List[Object](
List[Object](
List[Object](
List[Object](
List[Object](
List[Object](
List[Object](
List[Object](
)))))))) // 7 cons + 1 nil

val over5 =
List[Object](
List[Object](
List[Object](
List[Object](
List[Object](
List[Object](
List[Object](
List[Object]( List[Object]()
)))))))) // 7 cons + 1-sized array wrapping nil

val max6 =
List[Object]( // ::(
"1", "2", List[Object]( // 1, ::(2, ::(::(
"3", "4", List[Object]( // 3, ::(4, ::(::(
List[Object]() // Nil, Nil
) // ), Nil))
) // ), Nil))
) // )
// 7 cons + 4 string heads + 4 nils for nested lists

val max7 =
List[Object]( // ::(
"1", "2", List[Object]( // 1, ::(2, ::(::(
"3", "4", List[Object]( // 3, ::(4, ::(::(
"5" // 5, Nil
) // ), Nil))
) // ), Nil))
) // )
// 7 cons + 5 string heads + 3 nils for nested lists
Loading