From 72bb26c2270faaf3046105637b5afd0e55fa6c71 Mon Sep 17 00:00:00 2001 From: Rok Mihevc Date: Tue, 24 Sep 2024 13:45:29 +0200 Subject: [PATCH] Stricter equality check --- cpp/src/arrow/extension/json.cc | 25 ++++++++++++++----------- cpp/src/arrow/extension/json.h | 4 +++- cpp/src/arrow/extension/json_test.cc | 14 ++++++++++++++ 3 files changed, 31 insertions(+), 12 deletions(-) diff --git a/cpp/src/arrow/extension/json.cc b/cpp/src/arrow/extension/json.cc index d793233c2b573..5e396738e62df 100644 --- a/cpp/src/arrow/extension/json.cc +++ b/cpp/src/arrow/extension/json.cc @@ -28,17 +28,13 @@ namespace arrow::extension { bool JsonExtensionType::ExtensionEquals(const ExtensionType& other) const { - return other.extension_name() == this->extension_name(); + return other.extension_name() == this->extension_name() && + other.storage_type()->Equals(storage_type_); } Result> JsonExtensionType::Deserialize( std::shared_ptr storage_type, const std::string& serialized) const { - if (storage_type->id() != Type::STRING && storage_type->id() != Type::STRING_VIEW && - storage_type->id() != Type::LARGE_STRING) { - return Status::Invalid("Invalid storage type for JsonExtensionType: ", - storage_type->ToString()); - } - return std::make_shared(storage_type); + return JsonExtensionType::Make(storage_type); } std::string JsonExtensionType::Serialize() const { return ""; } @@ -51,11 +47,18 @@ std::shared_ptr JsonExtensionType::MakeArray( return std::make_shared(data); } -std::shared_ptr json(const std::shared_ptr storage_type) { - ARROW_CHECK(storage_type->id() != Type::STRING || - storage_type->id() != Type::STRING_VIEW || - storage_type->id() != Type::LARGE_STRING); +Result> JsonExtensionType::Make( + const std::shared_ptr& storage_type) { + if (storage_type->id() != Type::STRING && storage_type->id() != Type::STRING_VIEW && + storage_type->id() != Type::LARGE_STRING) { + return Status::Invalid("Invalid storage type for JsonExtensionType: ", + storage_type->ToString()); + } return std::make_shared(storage_type); } +std::shared_ptr json(const std::shared_ptr& storage_type) { + return JsonExtensionType::Make(storage_type).ValueOrDie(); +} + } // namespace arrow::extension diff --git a/cpp/src/arrow/extension/json.h b/cpp/src/arrow/extension/json.h index 4793ab2bc9b36..4d475536cff59 100644 --- a/cpp/src/arrow/extension/json.h +++ b/cpp/src/arrow/extension/json.h @@ -45,12 +45,14 @@ class ARROW_EXPORT JsonExtensionType : public ExtensionType { std::shared_ptr MakeArray(std::shared_ptr data) const override; + static Result> Make(const std::shared_ptr& storage_type); + private: std::shared_ptr storage_type_; }; /// \brief Return a JsonExtensionType instance. ARROW_EXPORT std::shared_ptr json( - std::shared_ptr storage_type = utf8()); + const std::shared_ptr& storage_type = utf8()); } // namespace arrow::extension diff --git a/cpp/src/arrow/extension/json_test.cc b/cpp/src/arrow/extension/json_test.cc index 143e4f9ceeac7..b938ddb2cfef3 100644 --- a/cpp/src/arrow/extension/json_test.cc +++ b/cpp/src/arrow/extension/json_test.cc @@ -80,4 +80,18 @@ TEST_F(TestJsonExtensionType, InvalidUTF8) { } } +TEST_F(TestJsonExtensionType, StorageTypeValidation) { + ASSERT_TRUE(json(utf8())->Equals(json(utf8()))); + ASSERT_FALSE(json(large_utf8())->Equals(json(utf8()))); + ASSERT_FALSE(json(utf8_view())->Equals(json(utf8()))); + ASSERT_FALSE(json(utf8_view())->Equals(json(large_utf8()))); + + for (const auto& storage_type : {int16(), binary(), float64(), null()}) { + ASSERT_RAISES_WITH_MESSAGE(Invalid, + "Invalid: Invalid storage type for JsonExtensionType: " + + storage_type->ToString(), + extension::JsonExtensionType::Make(storage_type)); + } +} + } // namespace arrow