Skip to content

Commit

Permalink
fix mllib test and warning
Browse files (browse the repository at this point in the history)
  • Loading branch information
Davies Liu committed Apr 6, 2015
1 parent ef1fc2f commit ad7c374
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 6 deletions.
Original file line number | Diff line number | Diff line change
Expand Up @@ -18,6 +18,7 @@
package org.apache.spark.mllib.api.python

import org.apache.spark.api.java.JavaRDD
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.recommendation.{MatrixFactorizationModel, Rating}
import org.apache.spark.rdd.RDD

Expand All @@ -31,10 +32,14 @@ private[python] class MatrixFactorizationModelWrapper(model: MatrixFactorization
predict(SerDe.asTupleRDD(userAndProducts.rdd))

def getUserFeatures: RDD[Array[Any]] = {
SerDe.fromTuple2RDD(userFeatures.asInstanceOf[RDD[(Any, Any)]])
SerDe.fromTuple2RDD(userFeatures.map {
case (user, feature) => (user, Vectors.dense(feature))
}.asInstanceOf[RDD[(Any, Any)]])
}

def getProductFeatures: RDD[Array[Any]] = {
SerDe.fromTuple2RDD(productFeatures.asInstanceOf[RDD[(Any, Any)]])
SerDe.fromTuple2RDD(productFeatures.map {
case (product, feature) => (product, Vectors.dense(feature))
}.asInstanceOf[RDD[(Any, Any)]])
}
}
8 changes: 4 additions & 4 deletions python/pyspark/mllib/tests.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -382,11 +382,11 @@ def test_serialization(self):
def test_infer_schema(self):
sqlCtx = SQLContext(self.sc)
rdd = self.sc.parallelize([LabeledPoint(1.0, self.dv1), LabeledPoint(0.0, self.sv1)])
srdd = sqlCtx.inferSchema(rdd)
schema = srdd.schema
df = rdd.toDF()
schema = df.schema
field = [f for f in schema.fields if f.name == "features"][0]
self.assertEqual(field.dataType, self.udt)
vectors = srdd.map(lambda p: p.features).collect()
vectors = df.map(lambda p: p.features).collect()
self.assertEqual(len(vectors), 2)
for v in vectors:
if isinstance(v, SparseVector):
Expand Down Expand Up @@ -639,7 +639,7 @@ def test_right_number_of_results(self):

class SerDeTest(PySparkTestCase):
def test_to_java_object_rdd(self): # SPARK-6660
data = RandomRDDs.uniformRDD(self.sc, 10, 5, seed=0L)
data = RandomRDDs.uniformRDD(self.sc, 10, 5, seed=0)
self.assertEqual(_to_java_object_rdd(data).count(), 10)


Expand Down

0 comments on commit ad7c374

Please sign in to comment.