Summary (free text; patch tools skip everything before the first "diff --git" line).

  * diff.cc: drop the GetView(const StructArray&, int64_t) overload and add
    StructValueComparator, which treats two struct slots as equal when both are
    null, or when both are valid and every per-field comparator reports equality
    (stopping at the first mismatching field).
  * diff.cc: add ValueComparatorFactory::Visit(const StructType&, ...), which builds
    one comparator per struct field via Create() and wraps them in a
    StructValueComparator.
  * diff_test.cc: add StructFieldComparison and NestedStructComparison tests.

NOTE(review): the unchanged context lines "template <typename RunEndCType>" and
"checked_cast<const RunEndEncodedArray&>(base)" are presumed spellings of the
surrounding upstream code -- confirm against the target tree before applying.

diff --git a/cpp/src/arrow/array/diff.cc b/cpp/src/arrow/array/diff.cc
index fd907e3c7b2..050fe65b064 100644
--- a/cpp/src/arrow/array/diff.cc
+++ b/cpp/src/arrow/array/diff.cc
@@ -93,12 +93,6 @@ struct UnitSlice {
   bool operator!=(const UnitSlice& other) const { return !(*this == other); }
 };
 
-// FIXME(bkietz) this is inefficient;
-// StructArray's fields can be diffed independently then merged
-UnitSlice GetView(const StructArray& array, int64_t index) {
-  return UnitSlice{&array, index};
-}
-
 UnitSlice GetView(const UnionArray& array, int64_t index) {
   return UnitSlice{&array, index};
 }
@@ -164,6 +158,45 @@ struct DefaultValueComparator : public ValueComparator {
   }
 };
 
+class StructValueComparator : public ValueComparator {
+ private:
+  const StructArray& base_;
+  const StructArray& target_;
+  std::vector<std::unique_ptr<ValueComparator>> field_comparators_;
+
+ public:
+  StructValueComparator(const StructArray& base, const StructArray& target,
+                        std::vector<std::unique_ptr<ValueComparator>>&& field_comparators)
+      : base_(base), target_(target), field_comparators_(std::move(field_comparators)) {
+    DCHECK_EQ(*base_.type(), *target_.type());
+    DCHECK_EQ(base_.num_fields(), static_cast<int>(field_comparators_.size()));
+  }
+
+  ~StructValueComparator() override = default;
+
+  bool Equals(int64_t base_index, int64_t target_index) override {
+    const bool base_valid = base_.IsValid(base_index);
+    const bool target_valid = target_.IsValid(target_index);
+
+    if (base_valid != target_valid) {
+      return false;
+    }
+
+    if (!base_valid) {
+      return true;  // Both null
+    }
+
+    // Compare each field independently with early termination
+    for (const auto& field_comparator : field_comparators_) {
+      if (!field_comparator->Equals(base_index, target_index)) {
+        return false;
+      }
+    }
+
+    return true;
+  }
+};
+
 template <typename RunEndCType>
 class REEValueComparator : public ValueComparator {
  private:
@@ -308,6 +341,26 @@ class ValueComparatorFactory {
     return Status::NotImplemented("dictionary type");
   }
 
+  Status Visit(const StructType& struct_type, const Array& base, const Array& target) {
+    const auto& base_struct = checked_cast<const StructArray&>(base);
+    const auto& target_struct = checked_cast<const StructArray&>(target);
+
+    // Create comparators for each field
+    std::vector<std::unique_ptr<ValueComparator>> field_comparators;
+    field_comparators.reserve(struct_type.num_fields());
+
+    for (int i = 0; i < struct_type.num_fields(); ++i) {
+      ARROW_ASSIGN_OR_RAISE(auto field_comparator,
+                            Create(*struct_type.field(i)->type(), *base_struct.field(i),
+                                   *target_struct.field(i)));
+      field_comparators.push_back(std::move(field_comparator));
+    }
+
+    comparator_ = std::make_unique<StructValueComparator>(base_struct, target_struct,
+                                                          std::move(field_comparators));
+    return Status::OK();
+  }
+
   Status Visit(const RunEndEncodedType& ree_type, const Array& base,
                const Array& target) {
     const auto& base_ree = checked_cast<const RunEndEncodedArray&>(base);
diff --git a/cpp/src/arrow/array/diff_test.cc b/cpp/src/arrow/array/diff_test.cc
index 76f4202992f..6424e3eef3d 100644
--- a/cpp/src/arrow/array/diff_test.cc
+++ b/cpp/src/arrow/array/diff_test.cc
@@ -816,6 +816,83 @@ TEST_F(DiffTest, CompareRandomStruct) {
   }
 }
 
+TEST_F(DiffTest, StructFieldComparison) {
+  // test struct field-by-field comparison
+  auto type = struct_(
+      {field("first", int32()), field("second", utf8()), field("third", int64())});
+
+  // first field differs
+  base_ = ArrayFromJSON(type, R"([{"first": 1, "second": "a", "third": 100}])");
+  target_ = ArrayFromJSON(type, R"([{"first": 2, "second": "a", "third": 100}])");
+  DoDiff();
+  AssertInsertIs("[false, false, true]");
+  AssertRunLengthIs("[0, 0, 0]");
+
+  // second field differs
+  base_ = ArrayFromJSON(type, R"([{"first": 1, "second": "a", "third": 100}])");
+  target_ = ArrayFromJSON(type, R"([{"first": 1, "second": "b", "third": 100}])");
+  DoDiff();
+  AssertInsertIs("[false, false, true]");
+  AssertRunLengthIs("[0, 0, 0]");
+
+  // third field differs
+  base_ = ArrayFromJSON(type, R"([{"first": 1, "second": "a", "third": 100}])");
+  target_ = ArrayFromJSON(type, R"([{"first": 1, "second": "a", "third": 200}])");
+  DoDiff();
+  AssertInsertIs("[false, false, true]");
+  AssertRunLengthIs("[0, 0, 0]");
+
+  // all fields equal
+  base_ = ArrayFromJSON(type, R"([{"first": 1, "second": "a", "third": 100}])");
+  target_ = ArrayFromJSON(type, R"([{"first": 1, "second": "a", "third": 100}])");
+  DoDiff();
+  AssertInsertIs("[false]");
+  AssertRunLengthIs("[1]");
+}
+
+TEST_F(DiffTest, NestedStructComparison) {
+  // test nested struct comparison
+  auto inner_type = struct_({field("x", int32()), field("y", int32())});
+  auto outer_type =
+      struct_({field("id", int32()), field("inner", inner_type), field("name", utf8())});
+
+  // outer first field differs
+  base_ = ArrayFromJSON(outer_type,
+                        R"([{"id": 1, "inner": {"x": 10, "y": 20}, "name": "test"}])");
+  target_ = ArrayFromJSON(outer_type,
+                          R"([{"id": 2, "inner": {"x": 10, "y": 20}, "name": "test"}])");
+  DoDiff();
+  AssertInsertIs("[false, false, true]");
+  AssertRunLengthIs("[0, 0, 0]");
+
+  // nested struct first field differs
+  base_ = ArrayFromJSON(outer_type,
+                        R"([{"id": 1, "inner": {"x": 10, "y": 20}, "name": "test"}])");
+  target_ = ArrayFromJSON(outer_type,
+                          R"([{"id": 1, "inner": {"x": 99, "y": 20}, "name": "test"}])");
+  DoDiff();
+  AssertInsertIs("[false, false, true]");
+  AssertRunLengthIs("[0, 0, 0]");
+
+  // nested struct second field differs
+  base_ = ArrayFromJSON(outer_type,
+                        R"([{"id": 1, "inner": {"x": 10, "y": 20}, "name": "test"}])");
+  target_ = ArrayFromJSON(outer_type,
+                          R"([{"id": 1, "inner": {"x": 10, "y": 99}, "name": "test"}])");
+  DoDiff();
+  AssertInsertIs("[false, false, true]");
+  AssertRunLengthIs("[0, 0, 0]");
+
+  // all equal including nested struct
+  base_ = ArrayFromJSON(outer_type,
+                        R"([{"id": 1, "inner": {"x": 10, "y": 20}, "name": "test"}])");
+  target_ = ArrayFromJSON(outer_type,
+                          R"([{"id": 1, "inner": {"x": 10, "y": 20}, "name": "test"}])");
+  DoDiff();
+  AssertInsertIs("[false]");
+  AssertRunLengthIs("[1]");
+}
+
 TEST_F(DiffTest, CompareHalfFloat) {
   auto first = ArrayFromJSON(float16(), "[1.1, 2.0, 2.5, 3.3]");
   auto second = ArrayFromJSON(float16(), "[1.1, 4.0, 3.5, 3.3]");