add support for DateType

Hi,

Based on issue databricks#67, I created this pull request.

Author: Nihed MBAREK <nihedmm@gmail.com>
Author: vlyubin <vlyubin@gmail.com>
Author: nihed <nihedmm@gmail.com>

Closes databricks#124 from nihed/master.
nihed authored and vlyubin committed Feb 16, 2017
1 parent 4a3284b commit c19f01a
Showing 4 changed files with 35 additions and 10 deletions.
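
For orientation, a minimal sketch of what this change enables (the output path and values here are illustrative, not part of the patch): a DataFrame with a DateType column can now be written to Avro, with each date stored as an Avro long of epoch milliseconds.

    import java.sql.Date

    import com.databricks.spark.avro._
    import org.apache.spark.sql.{Row, SparkSession}
    import org.apache.spark.sql.types._

    val spark = SparkSession.builder().master("local[1]").appName("date-demo").getOrCreate()
    val schema = StructType(Seq(StructField("date", DateType, nullable = true)))
    val rows = spark.sparkContext.parallelize(Seq(Row(Date.valueOf("2016-01-04")), Row(null)))
    val df = spark.createDataFrame(rows, schema)

    df.write.avro("/tmp/dates.avro")                  // DateType is written as an Avro long
    spark.read.avro("/tmp/dates.avro").printSchema()  // the column reads back as LongType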
src/main/scala/com/databricks/spark/avro/AvroOutputWriter.scala

@@ -19,6 +19,7 @@ package com.databricks.spark.avro
 import java.io.{IOException, OutputStream}
 import java.nio.ByteBuffer
 import java.sql.Timestamp
+import java.sql.Date
 import java.util.HashMap
 
 import org.apache.hadoop.fs.Path
@@ -90,6 +91,8 @@ private[avro] class AvroOutputWriter(
      case _: DecimalType => (item: Any) => if (item == null) null else item.toString
      case TimestampType => (item: Any) =>
        if (item == null) null else item.asInstanceOf[Timestamp].getTime
+     case DateType => (item: Any) =>
+       if (item == null) null else item.asInstanceOf[Date].getTime
      case ArrayType(elementType, _) =>
        val elementConverter = createConverterToAvro(elementType, structName, recordNamespace)
        (item: Any) => {
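
The new DateType branch mirrors the TimestampType branch directly above it: a non-null java.sql.Date is unwrapped to its epoch-millisecond value with getTime, and null passes through as null. A quick illustration of the conversion (assuming the JVM default time zone is UTC, as the new test below arranges):

    val d = java.sql.Date.valueOf("2016-01-04")
    d.getTime  // 1451865600000L, i.e. milliseconds since 1970-01-01T00:00:00Z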
src/main/scala/com/databricks/spark/avro/SchemaConverters.scala

@@ -328,6 +328,7 @@ object SchemaConverters {
      case BinaryType => schemaBuilder.bytesType()
      case BooleanType => schemaBuilder.booleanType()
      case TimestampType => schemaBuilder.longType()
+     case DateType => schemaBuilder.longType()
 
      case ArrayType(elementType, _) =>
        val builder = getSchemaBuilder(dataType.asInstanceOf[ArrayType].containsNull)
@@ -371,6 +372,7 @@
      case BinaryType => newFieldBuilder.bytesType()
      case BooleanType => newFieldBuilder.booleanType()
      case TimestampType => newFieldBuilder.longType()
+     case DateType => newFieldBuilder.longType()
 
      case ArrayType(elementType, _) =>
        val builder = getSchemaBuilder(dataType.asInstanceOf[ArrayType].containsNull)
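
A consequence of mapping both types to longType(): the written Avro schema does not distinguish a date column from a timestamp column, and neither carries an Avro logical type, so on read-back both surface as plain longs (the new test below asserts exactly this). A minimal sketch of what these branches emit, using Avro's SchemaBuilder:

    import org.apache.avro.SchemaBuilder

    // Both the TimestampType and DateType branches produce this schema:
    val avroLong = SchemaBuilder.builder().longType()
    println(avroLong)  // "long"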
33 changes: 26 additions & 7 deletions src/test/scala/com/databricks/spark/avro/AvroSuite.scala
@@ -19,22 +19,21 @@ package com.databricks.spark.avro
 import java.io._
 import java.nio.ByteBuffer
 import java.nio.file.Files
-import java.sql.Timestamp
-import java.util.UUID
+import java.sql.{Date, Timestamp}
+import java.util.{TimeZone, UUID}
 
 import scala.collection.JavaConversions._
+
+import com.databricks.spark.avro.SchemaConverters.IncompatibleSchemaException
 import org.apache.avro.Schema
 import org.apache.avro.Schema.{Field, Type}
 import org.apache.avro.file.DataFileWriter
+import org.apache.avro.generic.GenericData.{EnumSymbol, Fixed}
 import org.apache.avro.generic.{GenericData, GenericDatumWriter, GenericRecord}
-import org.apache.avro.generic.GenericData.{EnumSymbol, Fixed}
 import org.apache.commons.io.FileUtils
-
-import org.apache.spark.sql.{AnalysisException, Row, SparkSession}
-import org.apache.spark.SparkContext
+import org.apache.spark.sql._
 import org.apache.spark.sql.types._
 import org.scalatest.{BeforeAndAfterAll, FunSuite}
-import com.databricks.spark.avro.SchemaConverters.IncompatibleSchemaException
 
 class AvroSuite extends FunSuite with BeforeAndAfterAll {
   val episodesFile = "src/test/resources/episodes.avro"
@@ -297,6 +296,26 @@ class AvroSuite extends FunSuite with BeforeAndAfterAll {
     }
   }
 
+  test("Date field type") {
+    TestUtils.withTempDir { dir =>
+      val schema = StructType(Seq(
+        StructField("float", FloatType, true),
+        StructField("date", DateType, true)
+      ))
+      TimeZone.setDefault(TimeZone.getTimeZone("UTC"))
+      val rdd = spark.sparkContext.parallelize(Seq(
+        Row(1f, null),
+        Row(2f, new Date(1451948400000L)),
+        Row(3f, new Date(1460066400500L))
+      ))
+      val df = spark.createDataFrame(rdd, schema)
+      df.write.avro(dir.toString)
+      assert(spark.read.avro(dir.toString).count == rdd.count)
+      assert(spark.read.avro(dir.toString).select("date").collect().map(_(0)).toSet ==
+        Array(null, 1451865600000L, 1459987200000L).toSet)
+    }
+  }
+
   test("Array data types") {
     TestUtils.withTempDir { dir =>
       val testSchema = StructType(Seq(
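
Two details of this test are worth spelling out. First, the read-back values are compared against Longs: with the schema mapping above, the Avro file stores a plain long, so the round trip surfaces the column as epoch milliseconds rather than as DateType. Second, the inputs carry a time-of-day component, but Spark's DateType keeps only the day, so each value is truncated to midnight in the JVM default time zone, which is why the test pins it to UTC. A sketch of the arithmetic behind the expected values:

    val msPerDay = 86400000L
    // 2016-01-04T23:00:00Z truncates to 2016-01-04T00:00:00Z:
    assert(1451948400000L / msPerDay * msPerDay == 1451865600000L)
    // 2016-04-07T22:00:00.500Z truncates to 2016-04-07T00:00:00Z:
    assert(1460066400500L / msPerDay * msPerDay == 1459987200000L)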
src/test/scala/com/databricks/spark/avro/AvroWriteBenchmark.scala

@@ -16,15 +16,15 @@
 
 package com.databricks.spark.avro
 
+import java.sql.Date
 import java.util.concurrent.TimeUnit
 
 import scala.collection.JavaConversions._
 import scala.util.Random
 
 import com.google.common.io.Files
 import org.apache.commons.io.FileUtils
-
-import org.apache.spark.sql.{Row, SparkSession}
+import org.apache.spark.sql._
 import org.apache.spark.sql.types._
 
 /**
@@ -40,6 +40,7 @@ object AvroWriteBenchmark {
   val testSchema = StructType(Seq(
     StructField("StringField", StringType, false),
     StructField("IntField", IntegerType, true),
+    StructField("dateField", DateType, true),
     StructField("DoubleField", DoubleType, false),
     StructField("DecimalField", DecimalType(10, 10), true),
     StructField("ArrayField", ArrayType(BooleanType), false),
@@ -48,7 +48,7 @@
 
   private def generateRandomRow(): Row = {
     val rand = new Random()
-    Row(rand.nextString(defaultSize), rand.nextInt(), rand.nextDouble(), rand.nextDouble(),
+    Row(rand.nextString(defaultSize), rand.nextInt(), new Date(rand.nextLong()), rand.nextDouble(), rand.nextDouble(),
       TestUtils.generateRandomArray(rand, defaultSize).toSeq,
       TestUtils.generateRandomMap(rand, defaultSize).toMap, Row(rand.nextInt()))
   }
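
One caveat on the benchmark row generator: new Date(rand.nextLong()) draws from the full Long range, so most generated dates land millions of years away from the epoch. That is harmless for measuring write throughput, but a bounded generator is an easy alternative if realistic values matter (a hypothetical helper, not part of this patch):

    import java.sql.Date
    import scala.util.Random

    // Hypothetical alternative: random dates within roughly the last ten years.
    def randomRecentDate(rand: Random): Date =
      new Date(System.currentTimeMillis() - rand.nextInt(3650).toLong * 86400000L)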
