Skip to content

Commit

Permalink
Extract common code
Browse files Browse the repository at this point in the history
Signed-off-by: Nghia Truong <nghiat@nvidia.com>
  • Loading branch information
ttnghia committed Nov 5, 2023
1 parent 14f230f commit 1b464ec
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -126,8 +126,7 @@ object GpuParquetFileFormat {
if (schemaHasTimestamps) {
meta.willNotWorkOnGpu("LEGACY rebase mode for int96 timestamps is not supported")
}
case other => meta.willNotWorkOnGpu(s"Invalid datetime rebase mode from config: $other " +
"(must be either 'EXCEPTION', 'LEGACY', or 'CORRECTED')")
case other => meta.willNotWorkOnGpu(DateTimeRebaseUtils.invalidRebaseModeMessage(other))
}

SparkShimImpl.parquetRebaseWrite(sqlConf) match {
Expand All @@ -139,8 +138,7 @@ object GpuParquetFileFormat {
s"session: ${SQLConf.get.sessionLocalTimeZone}). " +
" Set both of the timezones to UTC to enable LEGACY rebase support.")
}
case other => meta.willNotWorkOnGpu(s"Invalid datetime rebase mode from config: $other " +
"(must be either 'EXCEPTION', 'LEGACY', or 'CORRECTED')")
case other => meta.willNotWorkOnGpu(DateTimeRebaseUtils.invalidRebaseModeMessage(other))
}

if (meta.canThisBeReplaced) {
Expand Down Expand Up @@ -191,9 +189,11 @@ class GpuParquetFileFormat extends ColumnarFileFormat with Logging {
val conf = ContextUtil.getConfiguration(job)

val outputTimestampType = sqlConf.parquetOutputTimestampType
val dateTimeRebaseMode = sparkSession.sqlContext.getConf(SparkShimImpl.parquetRebaseWriteKey)
val dateTimeRebaseMode = DateTimeRebaseUtils.getRebaseModeFromName(
sparkSession.sqlContext.getConf(SparkShimImpl.parquetRebaseWriteKey))
val timestampRebaseMode = if (outputTimestampType.equals(ParquetOutputTimestampType.INT96)) {
sparkSession.sqlContext.getConf(SparkShimImpl.int96ParquetRebaseWriteKey)
DateTimeRebaseUtils.getRebaseModeFromName(
sparkSession.sqlContext.getConf(SparkShimImpl.int96ParquetRebaseWriteKey))
} else {
dateTimeRebaseMode
}
Expand Down Expand Up @@ -300,19 +300,19 @@ class GpuParquetWriter(
dataSchema: StructType,
compressionType: CompressionType,
outputTimestampType: String,
dateRebaseMode: String,
timestampRebaseMode: String,
dateRebaseMode: DateTimeRebaseMode,
timestampRebaseMode: DateTimeRebaseMode,
context: TaskAttemptContext,
parquetFieldIdEnabled: Boolean)
extends ColumnarOutputWriter(context, dataSchema, "Parquet", true) {
override def throwIfRebaseNeededInExceptionMode(batch: ColumnarBatch): Unit = {
val cols = GpuColumnVector.extractBases(batch)
cols.foreach { col =>
if (dateRebaseMode.equals("EXCEPTION") &&
if (dateRebaseMode == DateTimeRebaseException &&
DateTimeRebaseUtils.isDateRebaseNeededInWrite(col)) {
throw DataSourceUtils.newRebaseExceptionInWrite("Parquet")
}
else if (timestampRebaseMode.equals("EXCEPTION") &&
else if (timestampRebaseMode == DateTimeRebaseException &&
DateTimeRebaseUtils.isTimeRebaseNeededInWrite(col)) {
throw DataSourceUtils.newRebaseExceptionInWrite("Parquet")
}
Expand All @@ -333,14 +333,14 @@ class GpuParquetWriter(
ColumnCastUtil.deepTransform(cv, Some(dt)) {
case (cv, _) if cv.getType.isTimestampType =>
if(cv.getType == DType.TIMESTAMP_DAYS) {
if (dateRebaseMode.equals("LEGACY")) {
if (dateRebaseMode == DateTimeRebaseLegacy) {
DateTimeRebase.rebaseGregorianToJulian(cv)
} else {
cv.copyToColumnVector()
}
} else { /* timestamp */
val typeMillis = ParquetOutputTimestampType.TIMESTAMP_MILLIS.toString
if (timestampRebaseMode.equals("LEGACY")) {
if (timestampRebaseMode == DateTimeRebaseLegacy) {
val rebasedTimestampAsMicros = if(cv.getType == DType.TIMESTAMP_MICROSECONDS) {
DateTimeRebase.rebaseGregorianToJulian(cv)
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -215,8 +215,7 @@ object GpuParquetScan {
if (schemaHasTimestamps) {
meta.willNotWorkOnGpu("LEGACY rebase mode for dates and timestamps is not supported")
}
case other => meta.willNotWorkOnGpu(s"Invalid datetime rebase mode from config: $other " +
"(must be either 'EXCEPTION', 'LEGACY', or 'CORRECTED')")
case other => meta.willNotWorkOnGpu(DateTimeRebaseUtils.invalidRebaseModeMessage(other))
}

sqlConf.get(SparkShimImpl.parquetRebaseReadKey) match {
Expand All @@ -228,8 +227,7 @@ object GpuParquetScan {
if (schemaHasDates || schemaHasTimestamps) {
meta.willNotWorkOnGpu("LEGACY rebase mode for dates and timestamps is not supported")
}
case other => meta.willNotWorkOnGpu(s"Invalid datetime rebase mode from config: $other " +
"(must be either 'EXCEPTION', 'LEGACY', or 'CORRECTED')")
case other => meta.willNotWorkOnGpu(DateTimeRebaseUtils.invalidRebaseModeMessage(other))
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,16 @@ case object DateTimeRebaseLegacy extends DateTimeRebaseMode("LEGACY")
case object DateTimeRebaseCorrected extends DateTimeRebaseMode("CORRECTED")

object DateTimeRebaseUtils {
def invalidRebaseModeMessage(name: String): String =
s"Invalid datetime rebase mode: $name (must be either 'EXCEPTION', 'LEGACY', or 'CORRECTED')"

def getRebaseModeFromName(name: String): DateTimeRebaseMode = name match {
case DateTimeRebaseException.value => DateTimeRebaseException
case DateTimeRebaseLegacy.value => DateTimeRebaseLegacy
case DateTimeRebaseCorrected.value => DateTimeRebaseCorrected
case _ => throw new IllegalArgumentException(invalidRebaseModeMessage(name))
}

// Copied from Spark
private val SPARK_VERSION_METADATA_KEY = "org.apache.spark.version"
private val SPARK_LEGACY_DATETIME_METADATA_KEY = "org.apache.spark.legacyDateTime"
Expand All @@ -66,14 +76,7 @@ object DateTimeRebaseUtils {
} else {
DateTimeRebaseCorrected
}
}.getOrElse(modeByConfig match {
case DateTimeRebaseException.value => DateTimeRebaseException
case DateTimeRebaseLegacy.value => DateTimeRebaseLegacy
case DateTimeRebaseCorrected.value => DateTimeRebaseCorrected
case _ => throw new IllegalArgumentException(
s"Invalid datetime rebase mode from config: $modeByConfig " +
"(must be either 'EXCEPTION', 'LEGACY', or 'CORRECTED')")
})
}.getOrElse(getRebaseModeFromName(modeByConfig))

// Check the timezone of the file if the mode is LEGACY.
if (mode == DateTimeRebaseLegacy) {
Expand Down

0 comments on commit 1b464ec

Please sign in to comment.