Skip to content

Commit

Permalink
extract common code
Browse files Browse the repository at this point in the history
  • Loading branch information
khwilson committed Sep 30, 2024
1 parent ac05e70 commit 1339df5
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 30 deletions.
12 changes: 1 addition & 11 deletions cpp/src/arrow/compute/kernels/aggregate_basic.cc
Original file line number Diff line number Diff line change
Expand Up @@ -366,17 +366,7 @@ struct ProductImpl : public ScalarAggregator {

Status Finalize(KernelContext*, Datum* out) override {
std::shared_ptr<DataType> out_type_;
if (out_type->id() == Type::DECIMAL128) {
auto cast_type = checked_pointer_cast<Decimal128Type>(this->out_type);
ARROW_ASSIGN_OR_RAISE(out_type_, Decimal128Type::Make(Decimal128Type::kMaxPrecision,
cast_type->scale()));
} else if (out_type->id() == Type::DECIMAL256) {
auto cast_type = checked_pointer_cast<Decimal256Type>(this->out_type);
ARROW_ASSIGN_OR_RAISE(out_type_, Decimal256Type::Make(Decimal256Type::kMaxPrecision,
cast_type->scale()));
} else {
out_type_ = out_type;
}
ARROW_ASSIGN_OR_RAISE(out_type_, WidenDecimalToMaxPrecision(this->out_type));

if ((!options.skip_nulls && this->nulls_observed) ||
(this->count < options.min_count)) {
Expand Down
21 changes: 2 additions & 19 deletions cpp/src/arrow/compute/kernels/aggregate_basic.inc.cc
Original file line number Diff line number Diff line change
Expand Up @@ -93,17 +93,7 @@ struct SumImpl : public ScalarAggregator {

Status Finalize(KernelContext*, Datum* out) override {
std::shared_ptr<DataType> out_type_;
if (out_type->id() == Type::DECIMAL128) {
auto cast_type = checked_pointer_cast<Decimal128Type>(this->out_type);
ARROW_ASSIGN_OR_RAISE(out_type_, Decimal128Type::Make(Decimal128Type::kMaxPrecision,
cast_type->scale()));
} else if (out_type->id() == Type::DECIMAL256) {
auto cast_type = checked_pointer_cast<Decimal256Type>(this->out_type);
ARROW_ASSIGN_OR_RAISE(out_type_, Decimal256Type::Make(Decimal256Type::kMaxPrecision,
cast_type->scale()));
} else {
out_type_ = out_type;
}
ARROW_ASSIGN_OR_RAISE(out_type_, WidenDecimalToMaxPrecision(this->out_type));

if ((!options.skip_nulls && this->nulls_observed) ||
(this->count < options.min_count)) {
Expand Down Expand Up @@ -233,14 +223,7 @@ struct MeanImpl<ArrowType, SimdLevel, enable_if_decimal<ArrowType>>
template <typename T = ArrowType>
Status FinalizeImpl(Datum* out) {
std::shared_ptr<DataType> out_type_;
auto decimal_type = checked_pointer_cast<DecimalType>(this->out_type);
if (decimal_type->id() == Type::DECIMAL128) {
ARROW_ASSIGN_OR_RAISE(out_type_, Decimal128Type::Make(Decimal128Type::kMaxPrecision,
decimal_type->scale()));
} else {
ARROW_ASSIGN_OR_RAISE(out_type_, Decimal256Type::Make(Decimal256Type::kMaxPrecision,
decimal_type->scale()));
}
ARROW_ASSIGN_OR_RAISE(out_type_, WidenDecimalToMaxPrecision(this->out_type));

if ((!options.skip_nulls && this->nulls_observed) ||
(this->count < options.min_count) || (this->count == 0)) {
Expand Down
11 changes: 11 additions & 0 deletions cpp/src/arrow/compute/kernels/codegen_internal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -547,6 +547,17 @@ Status CastDecimalArgs(TypeHolder* begin, size_t count) {
return Status::OK();
}

Result<std::shared_ptr<DataType>> WidenDecimalToMaxPrecision(std::shared_ptr<DataType> type) {
if (type->id() == Type::DECIMAL128) {
auto cast_type = checked_pointer_cast<Decimal128Type>(type);
return Decimal128Type::Make(Decimal128Type::kMaxPrecision, cast_type->scale());
} else if (type->id() == Type::DECIMAL256) {
auto cast_type = checked_pointer_cast<Decimal256Type>(type);
return Decimal256Type::Make(Decimal256Type::kMaxPrecision, cast_type->scale());
}
return type;
}

bool HasDecimal(const std::vector<TypeHolder>& types) {
for (const auto& th : types) {
if (is_decimal(th.id())) {
Expand Down
3 changes: 3 additions & 0 deletions cpp/src/arrow/compute/kernels/codegen_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ using internal::BinaryBitBlockCounter;
using internal::BitBlockCount;
using internal::BitmapReader;
using internal::checked_cast;
using internal::checked_pointer_cast;
using internal::FirstTimeBitmapWriter;
using internal::GenerateBitsUnrolled;
using internal::VisitBitBlocks;
Expand Down Expand Up @@ -1382,6 +1383,8 @@ Status CastBinaryDecimalArgs(DecimalPromotion promotion, std::vector<TypeHolder>
ARROW_EXPORT
Status CastDecimalArgs(TypeHolder* begin, size_t count);

Result<std::shared_ptr<DataType>> WidenDecimalToMaxPrecision(std::shared_ptr<DataType> type);

ARROW_EXPORT
bool HasDecimal(const std::vector<TypeHolder>& types);

Expand Down

0 comments on commit 1339df5

Please sign in to comment.