Skip to content

Commit

Permalink
Replace explicit checking with DCHECK for invariants in row segmenter
Browse files Browse the repository at this point in the history
  • Loading branch information
zanmato1984 committed Sep 26, 2024
1 parent bc923bd commit d7c7e95
Showing 1 changed file with 32 additions and 49 deletions.
81 changes: 32 additions & 49 deletions cpp/src/arrow/compute/row/grouper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,8 @@ template <typename Value>
ARROW_DEPRECATED("Deprecated in 18.0.0 along with GetSegments.")
Status CheckForGetNextSegment(const std::vector<Value>& values, int64_t length,
int64_t offset, const std::vector<TypeHolder>& key_types) {
if (offset < 0 || offset > length) {
return Status::Invalid("invalid grouping segmenter offset: ", offset);
}
DCHECK_GE(offset, 0);
DCHECK_LE(offset, length);
if (values.size() != key_types.size()) {
return Status::Invalid("expected batch size ", key_types.size(), " but got ",
values.size());
Expand Down Expand Up @@ -173,10 +172,8 @@ struct SimpleKeySegmenter : public BaseRowSegmenter {
Result<Segment> GetNextSegmentDeprecated(const Scalar& scalar, int64_t offset,
int64_t length) {
ARROW_SUPPRESS_DEPRECATION_WARNING
ARROW_RETURN_NOT_OK(CheckType(*scalar.type));
if (!scalar.is_valid) {
return Status::Invalid("segmenting an invalid scalar");
}
DCHECK(is_fixed_width(*scalar.type));
DCHECK(scalar.is_valid);
auto data = checked_cast<const PrimitiveScalarBase&>(scalar).data();
bool extends = length > 0 ? ExtendDeprecated(data) : kEmptyExtends;
return MakeSegment(length, offset, length, extends);
Expand All @@ -188,7 +185,7 @@ struct SimpleKeySegmenter : public BaseRowSegmenter {
const uint8_t* array_bytes, int64_t offset,
int64_t length) {
ARROW_SUPPRESS_DEPRECATION_WARNING
RETURN_NOT_OK(CheckType(array_type));
DCHECK(is_fixed_width(array_type));
DCHECK_LE(offset, length);
int64_t byte_width = array_type.byte_width();
int64_t match_length = GetMatchLength(array_bytes + offset * byte_width, byte_width,
Expand All @@ -211,9 +208,7 @@ struct SimpleKeySegmenter : public BaseRowSegmenter {
}
ARROW_DCHECK(value.is_array());
const auto& array = value.array;
if (array.GetNullCount() > 0) {
return Status::NotImplemented("segmenting a nullable array");
}
// Nullable arrays are not supported by this segmenter, so the invariant is a null
// count of exactly zero (a "> 0" assertion would fire on every null-free array).
DCHECK_EQ(array.GetNullCount(), 0);
return GetNextSegmentDeprecated(*array.type, GetValuesAsBytes(array), offset,
batch.length);
ARROW_UNSUPPRESS_DEPRECATION_WARNING
Expand All @@ -227,7 +222,7 @@ struct SimpleKeySegmenter : public BaseRowSegmenter {
}

const auto& value = batch.values[0];
RETURN_NOT_OK(CheckType(*value.type()));
DCHECK(is_fixed_width(*value.type()));

std::vector<Segment> segments;
const void* key_data;
Expand Down Expand Up @@ -261,13 +256,6 @@ struct SimpleKeySegmenter : public BaseRowSegmenter {
}

private:
// Returns OK when `type` is fixed-width; otherwise returns an Invalid status
// naming the unsupported type.
static Status CheckType(const DataType& type) {
  const bool supported = is_fixed_width(type);
  return supported
             ? Status::OK()
             : Status::Invalid("SimpleKeySegmenter does not support type ", type);
}

static const uint8_t* GetValuesAsBytes(const ArraySpan& data, int64_t offset = 0) {
DCHECK_GT(data.type->byte_width(), 0);
int64_t absolute_byte_offset = (data.offset + offset) * data.type->byte_width();
Expand Down Expand Up @@ -354,23 +342,20 @@ struct AnyKeysSegmenter : public BaseRowSegmenter {
ARROW_RETURN_NOT_OK(grouper_->Reset());

ARROW_ASSIGN_OR_RAISE(auto datum, grouper_->Consume(batch, offset));
if (datum.is_array()) {
// `data` is an array whose index-0 corresponds to index `offset` of `batch`
const std::shared_ptr<ArrayData>& data = datum.array();
DCHECK_EQ(data->length, batch.length - offset);
ARROW_DCHECK(data->GetNullCount() == 0);
DCHECK_EQ(data->type->id(), GroupIdType::type_id);
const group_id_t* values = data->GetValues<group_id_t>(1);
int64_t cursor;
for (cursor = 1; cursor < data->length; cursor++) {
if (values[0] != values[cursor]) break;
}
int64_t length = cursor;
bool extends = length > 0 ? bound_extend(values) : kEmptyExtends;
return MakeSegment(batch.length, offset, length, extends);
} else {
return Status::Invalid("segmenting unsupported datum kind ", datum.kind());
DCHECK(datum.is_array());
// `data` is an array whose index-0 corresponds to index `offset` of `batch`
const std::shared_ptr<ArrayData>& data = datum.array();
DCHECK_EQ(data->length, batch.length - offset);
ARROW_DCHECK(data->GetNullCount() == 0);
DCHECK_EQ(data->type->id(), GroupIdType::type_id);
const group_id_t* values = data->GetValues<group_id_t>(1);
int64_t cursor;
for (cursor = 1; cursor < data->length; cursor++) {
if (values[0] != values[cursor]) break;
}
int64_t length = cursor;
bool extends = length > 0 ? bound_extend(values) : kEmptyExtends;
return MakeSegment(batch.length, offset, length, extends);
ARROW_UNSUPPRESS_DEPRECATION_WARNING
}

Expand Down Expand Up @@ -432,9 +417,7 @@ struct AnyKeysSegmenter : public BaseRowSegmenter {
Result<group_id_t> MapGroupIdAt(const Batch& batch, int64_t offset = 0) {
ARROW_ASSIGN_OR_RAISE(auto datum, grouper_->Consume(batch, offset,
/*length=*/1));
if (!datum.is_array()) {
return Status::Invalid("accessing unsupported datum kind ", datum.kind());
}
DCHECK(datum.is_array());
const std::shared_ptr<ArrayData>& data = datum.array();
ARROW_DCHECK(data->GetNullCount() == 0);
DCHECK_EQ(data->type->id(), GroupIdType::type_id);
Expand All @@ -448,17 +431,6 @@ struct AnyKeysSegmenter : public BaseRowSegmenter {
group_id_t save_group_id_;
};

// Rejects a negative consume offset, and resolves a negative `*consume_length`
// to `batch_length - consume_offset`. A non-negative length is left untouched.
Status CheckAndCapLengthForConsume(int64_t batch_length, int64_t& consume_offset,
                                   int64_t* consume_length) {
  const bool offset_is_negative = consume_offset < 0;
  if (offset_is_negative) {
    return Status::Invalid("invalid grouper consume offset: ", consume_offset);
  }
  int64_t& resolved_length = *consume_length;
  if (resolved_length < 0) resolved_length = batch_length - consume_offset;
  return Status::OK();
}

} // namespace

Result<std::unique_ptr<RowSegmenter>> MakeAnyKeysSegmenter(
Expand All @@ -481,6 +453,17 @@ Result<std::unique_ptr<RowSegmenter>> RowSegmenter::Make(

namespace {

// Rejects a negative consume offset, and resolves a negative `*consume_length`
// to `batch_length - consume_offset`. A non-negative length is left untouched.
Status CheckAndCapLengthForConsume(int64_t batch_length, int64_t& consume_offset,
                                   int64_t* consume_length) {
  const bool offset_is_negative = consume_offset < 0;
  if (offset_is_negative) {
    return Status::Invalid("invalid grouper consume offset: ", consume_offset);
  }
  int64_t& resolved_length = *consume_length;
  if (resolved_length < 0) resolved_length = batch_length - consume_offset;
  return Status::OK();
}

struct GrouperImpl : public Grouper {
static Result<std::unique_ptr<GrouperImpl>> Make(
const std::vector<TypeHolder>& key_types, ExecContext* ctx) {
Expand Down

0 comments on commit d7c7e95

Please sign in to comment.