[SPARK-21050][ML] Word2vec persistence overflow bug fix
## What changes were proposed in this pull request?

The method calculateNumberOfPartitions() uses Int, not Long (unlike the MLlib version), so it is very easy to overflow when calculating the number of partitions for ML persistence.

This modifies the calculations to use Long.
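
For example, with the vocabulary and vector sizes used in the new unit test (a sketch for illustration; the variable names below are not from the patch), the Int arithmetic wraps around:

```scala
// Approximate model size is (floatSize * vectorSize + averageWordSize) * numWords
val numWords = 1000000
val vectorSize = 5000

val intBytes = (4 * vectorSize + 15) * numWords    // Int math overflows: -1459836480
val longBytes = (4L * vectorSize + 15) * numWords  // Long math is correct: 20015000000
```

A negative "size" makes (approximateSizeInBytes / bufferSizeInBytes) + 1 negative as well, so the model data cannot be repartitioned sensibly before the Parquet write.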

## How was this patch tested?

New unit test.  I verified that the test fails before this patch.

Author: Joseph K. Bradley <joseph@databricks.com>

Closes apache#18265 from jkbradley/word2vec-save-fix.
jkbradley committed Jun 12, 2017
1 parent b1436c7 commit ff318c0
Showing 2 changed files with 38 additions and 10 deletions.

mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala (28 additions & 10 deletions)
```diff
@@ -19,6 +19,7 @@ package org.apache.spark.ml.feature
 
 import org.apache.hadoop.fs.Path
 
+import org.apache.spark.SparkContext
 import org.apache.spark.annotation.Since
 import org.apache.spark.ml.{Estimator, Model}
 import org.apache.spark.ml.linalg.{BLAS, Vector, Vectors, VectorUDT}
@@ -339,25 +340,42 @@ object Word2VecModel extends MLReadable[Word2VecModel] {
       val wordVectors = instance.wordVectors.getVectors
       val dataSeq = wordVectors.toSeq.map { case (word, vector) => Data(word, vector) }
       val dataPath = new Path(path, "data").toString
+      val bufferSizeInBytes = Utils.byteStringAsBytes(
+        sc.conf.get("spark.kryoserializer.buffer.max", "64m"))
+      val numPartitions = Word2VecModelWriter.calculateNumberOfPartitions(
+        bufferSizeInBytes, instance.wordVectors.wordIndex.size, instance.getVectorSize)
       sparkSession.createDataFrame(dataSeq)
-        .repartition(calculateNumberOfPartitions)
+        .repartition(numPartitions)
         .write
         .parquet(dataPath)
     }
+  }
 
-    def calculateNumberOfPartitions(): Int = {
-      val floatSize = 4
+  private[feature]
+  object Word2VecModelWriter {
+    /**
+     * Calculate the number of partitions to use in saving the model.
+     * [SPARK-11994] - We want to partition the model in partitions smaller than
+     * spark.kryoserializer.buffer.max
+     * @param bufferSizeInBytes  Set to spark.kryoserializer.buffer.max
+     * @param numWords  Vocab size
+     * @param vectorSize  Vector length for each word
+     */
+    def calculateNumberOfPartitions(
+        bufferSizeInBytes: Long,
+        numWords: Int,
+        vectorSize: Int): Int = {
+      val floatSize = 4L  // Use Long to help avoid overflow
       val averageWordSize = 15
-      // [SPARK-11994] - We want to partition the model in partitions smaller than
-      // spark.kryoserializer.buffer.max
-      val bufferSizeInBytes = Utils.byteStringAsBytes(
-        sc.conf.get("spark.kryoserializer.buffer.max", "64m"))
       // Calculate the approximate size of the model.
       // Assuming an average word size of 15 bytes, the formula is:
       // (floatSize * vectorSize + 15) * numWords
-      val numWords = instance.wordVectors.wordIndex.size
-      val approximateSizeInBytes = (floatSize * instance.getVectorSize + averageWordSize) * numWords
-      ((approximateSizeInBytes / bufferSizeInBytes) + 1).toInt
+      val approximateSizeInBytes = (floatSize * vectorSize + averageWordSize) * numWords
+      val numPartitions = (approximateSizeInBytes / bufferSizeInBytes) + 1
+      require(numPartitions < 10e8, s"Word2VecModel calculated that it needs $numPartitions " +
+        s"partitions to save this model, which is too large.  Try increasing " +
+        s"spark.kryoserializer.buffer.max so that Word2VecModel can use fewer partitions.")
+      numPartitions.toInt
    }
  }
```

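As a rough sketch of what the new helper computes (not part of the commit; the object is private[feature], so callers must sit in the org.apache.spark.ml.feature package, and the sizes here are assumed):

```scala
package org.apache.spark.ml.feature

import org.apache.spark.util.Utils

// Hypothetical driver, for illustration only
object NumPartitionsSketch {
  def main(args: Array[String]): Unit = {
    val bufferSizeInBytes = Utils.byteStringAsBytes("64m")  // 67108864 bytes
    // (4L * 5000 + 15) * 1000000 = 20015000000 bytes of vector data,
    // and 20015000000 / 67108864 + 1 = 299 partitions
    val numPartitions = Word2VecModel.Word2VecModelWriter.calculateNumberOfPartitions(
      bufferSizeInBytes, numWords = 1000000, vectorSize = 5000)
    println(numPartitions)  // 299
  }
}
```

Note that the require guard caps the result below 10e8 (1e9), comfortably under Int.MaxValue (~2.1e9), so the final .toInt can no longer overflow.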
mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala (10 additions & 0 deletions)

```diff
@@ -25,6 +25,7 @@ import org.apache.spark.ml.util.TestingUtils._
 import org.apache.spark.mllib.feature.{Word2VecModel => OldWord2VecModel}
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.sql.Row
+import org.apache.spark.util.Utils
 
 class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
 
@@ -188,6 +189,15 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
     assert(math.abs(similarity(5) - similarityLarger(5) / similarity(5)) > 1E-5)
   }
 
+  test("Word2Vec read/write numPartitions calculation") {
+    val smallModelNumPartitions = Word2VecModel.Word2VecModelWriter.calculateNumberOfPartitions(
+      Utils.byteStringAsBytes("64m"), numWords = 10, vectorSize = 5)
+    assert(smallModelNumPartitions === 1)
+    val largeModelNumPartitions = Word2VecModel.Word2VecModelWriter.calculateNumberOfPartitions(
+      Utils.byteStringAsBytes("64m"), numWords = 1000000, vectorSize = 5000)
+    assert(largeModelNumPartitions > 1)
+  }
+
   test("Word2Vec read/write") {
     val t = new Word2Vec()
       .setInputCol("myInputCol")
```

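The new guard also turns a silent overflow into a loud failure; a hedged sketch (artificial 1-byte buffer, same assumed model sizes as above):

```scala
// From within org.apache.spark.ml.feature, this now throws
// java.lang.IllegalArgumentException: requirement failed: Word2VecModel calculated
// that it needs 20015000001 partitions to save this model, which is too large. ...
val tooMany = Word2VecModel.Word2VecModelWriter.calculateNumberOfPartitions(
  bufferSizeInBytes = 1L, numWords = 1000000, vectorSize = 5000)
```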