From 1339df54b2a9c01a2ff9e5b6fd2297f0f5012528 Mon Sep 17 00:00:00 2001 From: Kevin H Wilson Date: Mon, 30 Sep 2024 10:04:31 -0400 Subject: [PATCH] extract common code --- .../arrow/compute/kernels/aggregate_basic.cc | 12 +---------- .../compute/kernels/aggregate_basic.inc.cc | 21 ++----------------- .../arrow/compute/kernels/codegen_internal.cc | 11 ++++++++++ .../arrow/compute/kernels/codegen_internal.h | 3 +++ 4 files changed, 17 insertions(+), 30 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/aggregate_basic.cc b/cpp/src/arrow/compute/kernels/aggregate_basic.cc index 042eca4747520..00be838550b9d 100644 --- a/cpp/src/arrow/compute/kernels/aggregate_basic.cc +++ b/cpp/src/arrow/compute/kernels/aggregate_basic.cc @@ -366,17 +366,7 @@ struct ProductImpl : public ScalarAggregator { Status Finalize(KernelContext*, Datum* out) override { std::shared_ptr out_type_; - if (out_type->id() == Type::DECIMAL128) { - auto cast_type = checked_pointer_cast(this->out_type); - ARROW_ASSIGN_OR_RAISE(out_type_, Decimal128Type::Make(Decimal128Type::kMaxPrecision, - cast_type->scale())); - } else if (out_type->id() == Type::DECIMAL256) { - auto cast_type = checked_pointer_cast(this->out_type); - ARROW_ASSIGN_OR_RAISE(out_type_, Decimal256Type::Make(Decimal256Type::kMaxPrecision, - cast_type->scale())); - } else { - out_type_ = out_type; - } + ARROW_ASSIGN_OR_RAISE(out_type_, WidenDecimalToMaxPrecision(this->out_type)); if ((!options.skip_nulls && this->nulls_observed) || (this->count < options.min_count)) { diff --git a/cpp/src/arrow/compute/kernels/aggregate_basic.inc.cc b/cpp/src/arrow/compute/kernels/aggregate_basic.inc.cc index d06f83910bf8a..4fa32edd50929 100644 --- a/cpp/src/arrow/compute/kernels/aggregate_basic.inc.cc +++ b/cpp/src/arrow/compute/kernels/aggregate_basic.inc.cc @@ -93,17 +93,7 @@ struct SumImpl : public ScalarAggregator { Status Finalize(KernelContext*, Datum* out) override { std::shared_ptr out_type_; - if (out_type->id() == Type::DECIMAL128) { - auto cast_type = checked_pointer_cast(this->out_type); - ARROW_ASSIGN_OR_RAISE(out_type_, Decimal128Type::Make(Decimal128Type::kMaxPrecision, - cast_type->scale())); - } else if (out_type->id() == Type::DECIMAL256) { - auto cast_type = checked_pointer_cast(this->out_type); - ARROW_ASSIGN_OR_RAISE(out_type_, Decimal256Type::Make(Decimal256Type::kMaxPrecision, - cast_type->scale())); - } else { - out_type_ = out_type; - } + ARROW_ASSIGN_OR_RAISE(out_type_, WidenDecimalToMaxPrecision(this->out_type)); if ((!options.skip_nulls && this->nulls_observed) || (this->count < options.min_count)) { @@ -233,14 +223,7 @@ struct MeanImpl> template Status FinalizeImpl(Datum* out) { std::shared_ptr out_type_; - auto decimal_type = checked_pointer_cast(this->out_type); - if (decimal_type->id() == Type::DECIMAL128) { - ARROW_ASSIGN_OR_RAISE(out_type_, Decimal128Type::Make(Decimal128Type::kMaxPrecision, - decimal_type->scale())); - } else { - ARROW_ASSIGN_OR_RAISE(out_type_, Decimal256Type::Make(Decimal256Type::kMaxPrecision, - decimal_type->scale())); - } + ARROW_ASSIGN_OR_RAISE(out_type_, WidenDecimalToMaxPrecision(this->out_type)); if ((!options.skip_nulls && this->nulls_observed) || (this->count < options.min_count) || (this->count == 0)) { diff --git a/cpp/src/arrow/compute/kernels/codegen_internal.cc b/cpp/src/arrow/compute/kernels/codegen_internal.cc index fa29c622e7241..fe46fc9da031b 100644 --- a/cpp/src/arrow/compute/kernels/codegen_internal.cc +++ b/cpp/src/arrow/compute/kernels/codegen_internal.cc @@ -547,6 +547,17 @@ Status CastDecimalArgs(TypeHolder* begin, size_t count) { return Status::OK(); } +Result> WidenDecimalToMaxPrecision(std::shared_ptr type) { + if (type->id() == Type::DECIMAL128) { + auto cast_type = checked_pointer_cast(type); + return Decimal128Type::Make(Decimal128Type::kMaxPrecision, cast_type->scale()); + } else if (type->id() == Type::DECIMAL256) { + auto cast_type = checked_pointer_cast(type); + return Decimal256Type::Make(Decimal256Type::kMaxPrecision, cast_type->scale()); + } + return type; +} + bool HasDecimal(const std::vector& types) { for (const auto& th : types) { if (is_decimal(th.id())) { diff --git a/cpp/src/arrow/compute/kernels/codegen_internal.h b/cpp/src/arrow/compute/kernels/codegen_internal.h index 29ab2c70e46c4..6eee2cac40eef 100644 --- a/cpp/src/arrow/compute/kernels/codegen_internal.h +++ b/cpp/src/arrow/compute/kernels/codegen_internal.h @@ -55,6 +55,7 @@ using internal::BinaryBitBlockCounter; using internal::BitBlockCount; using internal::BitmapReader; using internal::checked_cast; +using internal::checked_pointer_cast; using internal::FirstTimeBitmapWriter; using internal::GenerateBitsUnrolled; using internal::VisitBitBlocks; @@ -1382,6 +1383,8 @@ Status CastBinaryDecimalArgs(DecimalPromotion promotion, std::vector ARROW_EXPORT Status CastDecimalArgs(TypeHolder* begin, size_t count); +Result> WidenDecimalToMaxPrecision(std::shared_ptr type); + ARROW_EXPORT bool HasDecimal(const std::vector& types);