diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
index 4ca062c0b5adf..b6909b3386b71 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
@@ -19,6 +19,7 @@ package org.apache.spark.ml.feature
 
 import org.apache.hadoop.fs.Path
 
+import org.apache.spark.SparkContext
 import org.apache.spark.annotation.Since
 import org.apache.spark.ml.{Estimator, Model}
 import org.apache.spark.ml.linalg.{BLAS, Vector, Vectors, VectorUDT}
@@ -339,25 +340,42 @@ object Word2VecModel extends MLReadable[Word2VecModel] {
       val wordVectors = instance.wordVectors.getVectors
       val dataSeq = wordVectors.toSeq.map { case (word, vector) => Data(word, vector) }
       val dataPath = new Path(path, "data").toString
+      val bufferSizeInBytes = Utils.byteStringAsBytes(
+        sc.conf.get("spark.kryoserializer.buffer.max", "64m"))
+      val numPartitions = Word2VecModelWriter.calculateNumberOfPartitions(
+        bufferSizeInBytes, instance.wordVectors.wordIndex.size, instance.getVectorSize)
       sparkSession.createDataFrame(dataSeq)
-        .repartition(calculateNumberOfPartitions)
+        .repartition(numPartitions)
         .write
         .parquet(dataPath)
     }
+  }
 
-    def calculateNumberOfPartitions(): Int = {
-      val floatSize = 4
+  private[feature]
+  object Word2VecModelWriter {
+    /**
+     * Calculate the number of partitions to use in saving the model.
+     * [SPARK-11994] - We want to partition the model in partitions smaller than
+     * spark.kryoserializer.buffer.max
+     * @param bufferSizeInBytes  Set to spark.kryoserializer.buffer.max
+     * @param numWords  Vocab size
+     * @param vectorSize  Vector length for each word
+     */
+    def calculateNumberOfPartitions(
+        bufferSizeInBytes: Long,
+        numWords: Int,
+        vectorSize: Int): Int = {
+      val floatSize = 4L  // Use Long to help avoid overflow
       val averageWordSize = 15
-      // [SPARK-11994] - We want to partition the model in partitions smaller than
-      // spark.kryoserializer.buffer.max
-      val bufferSizeInBytes = Utils.byteStringAsBytes(
-        sc.conf.get("spark.kryoserializer.buffer.max", "64m"))
       // Calculate the approximate size of the model.
       // Assuming an average word size of 15 bytes, the formula is:
       // (floatSize * vectorSize + 15) * numWords
-      val numWords = instance.wordVectors.wordIndex.size
-      val approximateSizeInBytes = (floatSize * instance.getVectorSize + averageWordSize) * numWords
-      ((approximateSizeInBytes / bufferSizeInBytes) + 1).toInt
+      val approximateSizeInBytes = (floatSize * vectorSize + averageWordSize) * numWords
+      val numPartitions = (approximateSizeInBytes / bufferSizeInBytes) + 1
+      require(numPartitions < 10e8, s"Word2VecModel calculated that it needs $numPartitions " +
+        s"partitions to save this model, which is too large.  Try increasing " +
+        s"spark.kryoserializer.buffer.max so that Word2VecModel can use fewer partitions.")
+      numPartitions.toInt
     }
   }
 
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala
index a6a1c2b4f32bd..6183606a7b2ac 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala
@@ -25,6 +25,7 @@ import org.apache.spark.ml.util.TestingUtils._
 import org.apache.spark.mllib.feature.{Word2VecModel => OldWord2VecModel}
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.sql.Row
+import org.apache.spark.util.Utils
 
 class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
 
@@ -188,6 +189,15 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
     assert(math.abs(similarity(5) - similarityLarger(5) / similarity(5)) > 1E-5)
   }
 
+  test("Word2Vec read/write numPartitions calculation") {
+    val smallModelNumPartitions = Word2VecModel.Word2VecModelWriter.calculateNumberOfPartitions(
+      Utils.byteStringAsBytes("64m"), numWords = 10, vectorSize = 5)
+    assert(smallModelNumPartitions === 1)
+    val largeModelNumPartitions = Word2VecModel.Word2VecModelWriter.calculateNumberOfPartitions(
+      Utils.byteStringAsBytes("64m"), numWords = 1000000, vectorSize = 5000)
+    assert(largeModelNumPartitions > 1)
+  }
+
  test("Word2Vec read/write") {
    val t = new Word2Vec()
      .setInputCol("myInputCol")