Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 59 additions & 6 deletions cpp/src/arrow/array/diff.cc
Original file line number Diff line number Diff line change
Expand Up @@ -93,12 +93,6 @@ struct UnitSlice {
bool operator!=(const UnitSlice& other) const { return !(*this == other); }
};

// FIXME(bkietz) this is inefficient;
// StructArray's fields can be diffed independently then merged
UnitSlice GetView(const StructArray& array, int64_t index) {
return UnitSlice{&array, index};
}

UnitSlice GetView(const UnionArray& array, int64_t index) {
return UnitSlice{&array, index};
}
Expand Down Expand Up @@ -164,6 +158,45 @@ struct DefaultValueComparator : public ValueComparator {
}
};

class StructValueComparator : public ValueComparator {
private:
const StructArray& base_;
const StructArray& target_;
std::vector<std::unique_ptr<ValueComparator>> field_comparators_;

public:
StructValueComparator(const StructArray& base, const StructArray& target,
std::vector<std::unique_ptr<ValueComparator>>&& field_comparators)
: base_(base), target_(target), field_comparators_(std::move(field_comparators)) {
DCHECK_EQ(*base_.type(), *target_.type());
DCHECK_EQ(base_.num_fields(), static_cast<int>(field_comparators_.size()));
}

~StructValueComparator() override = default;

bool Equals(int64_t base_index, int64_t target_index) override {
const bool base_valid = base_.IsValid(base_index);
const bool target_valid = target_.IsValid(target_index);

if (base_valid != target_valid) {
return false;
}

if (!base_valid) {
return true; // Both null
}

// Compare each field independently with early termination
for (const auto& field_comparator : field_comparators_) {
if (!field_comparator->Equals(base_index, target_index)) {
return false;
}
}

return true;
}
};

template <typename RunEndCType>
class REEValueComparator : public ValueComparator {
private:
Expand Down Expand Up @@ -308,6 +341,26 @@ class ValueComparatorFactory {
return Status::NotImplemented("dictionary type");
}

Status Visit(const StructType& struct_type, const Array& base, const Array& target) {
const auto& base_struct = checked_cast<const StructArray&>(base);
const auto& target_struct = checked_cast<const StructArray&>(target);

// Create comparators for each field
std::vector<std::unique_ptr<ValueComparator>> field_comparators;
field_comparators.reserve(struct_type.num_fields());

for (int i = 0; i < struct_type.num_fields(); ++i) {
ARROW_ASSIGN_OR_RAISE(auto field_comparator,
Create(*struct_type.field(i)->type(), *base_struct.field(i),
*target_struct.field(i)));
field_comparators.push_back(std::move(field_comparator));
}

comparator_ = std::make_unique<StructValueComparator>(base_struct, target_struct,
std::move(field_comparators));
return Status::OK();
}

Status Visit(const RunEndEncodedType& ree_type, const Array& base,
const Array& target) {
const auto& base_ree = checked_cast<const RunEndEncodedArray&>(base);
Expand Down
77 changes: 77 additions & 0 deletions cpp/src/arrow/array/diff_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -816,6 +816,83 @@ TEST_F(DiffTest, CompareRandomStruct) {
}
}

TEST_F(DiffTest, StructFieldComparison) {
// test struct field-by-field comparison
auto type = struct_(
{field("first", int32()), field("second", utf8()), field("third", int64())});

// first field differs
base_ = ArrayFromJSON(type, R"([{"first": 1, "second": "a", "third": 100}])");
target_ = ArrayFromJSON(type, R"([{"first": 2, "second": "a", "third": 100}])");
DoDiff();
AssertInsertIs("[false, false, true]");
AssertRunLengthIs("[0, 0, 0]");

// second field differs
base_ = ArrayFromJSON(type, R"([{"first": 1, "second": "a", "third": 100}])");
target_ = ArrayFromJSON(type, R"([{"first": 1, "second": "b", "third": 100}])");
DoDiff();
AssertInsertIs("[false, false, true]");
AssertRunLengthIs("[0, 0, 0]");

// third field differs
base_ = ArrayFromJSON(type, R"([{"first": 1, "second": "a", "third": 100}])");
target_ = ArrayFromJSON(type, R"([{"first": 1, "second": "a", "third": 200}])");
DoDiff();
AssertInsertIs("[false, false, true]");
AssertRunLengthIs("[0, 0, 0]");

// all fields equal
base_ = ArrayFromJSON(type, R"([{"first": 1, "second": "a", "third": 100}])");
target_ = ArrayFromJSON(type, R"([{"first": 1, "second": "a", "third": 100}])");
DoDiff();
AssertInsertIs("[false]");
AssertRunLengthIs("[1]");
}

TEST_F(DiffTest, NestedStructComparison) {
// test nested struct comparison
auto inner_type = struct_({field("x", int32()), field("y", int32())});
auto outer_type =
struct_({field("id", int32()), field("inner", inner_type), field("name", utf8())});

// outer first field differs
base_ = ArrayFromJSON(outer_type,
R"([{"id": 1, "inner": {"x": 10, "y": 20}, "name": "test"}])");
target_ = ArrayFromJSON(outer_type,
R"([{"id": 2, "inner": {"x": 10, "y": 20}, "name": "test"}])");
DoDiff();
AssertInsertIs("[false, false, true]");
AssertRunLengthIs("[0, 0, 0]");

// nested struct first field differs
base_ = ArrayFromJSON(outer_type,
R"([{"id": 1, "inner": {"x": 10, "y": 20}, "name": "test"}])");
target_ = ArrayFromJSON(outer_type,
R"([{"id": 1, "inner": {"x": 99, "y": 20}, "name": "test"}])");
DoDiff();
AssertInsertIs("[false, false, true]");
AssertRunLengthIs("[0, 0, 0]");

// nested struct second field differs
base_ = ArrayFromJSON(outer_type,
R"([{"id": 1, "inner": {"x": 10, "y": 20}, "name": "test"}])");
target_ = ArrayFromJSON(outer_type,
R"([{"id": 1, "inner": {"x": 10, "y": 99}, "name": "test"}])");
DoDiff();
AssertInsertIs("[false, false, true]");
AssertRunLengthIs("[0, 0, 0]");

// all equal including nested struct
base_ = ArrayFromJSON(outer_type,
R"([{"id": 1, "inner": {"x": 10, "y": 20}, "name": "test"}])");
target_ = ArrayFromJSON(outer_type,
R"([{"id": 1, "inner": {"x": 10, "y": 20}, "name": "test"}])");
DoDiff();
AssertInsertIs("[false]");
AssertRunLengthIs("[1]");
}

TEST_F(DiffTest, CompareHalfFloat) {
auto first = ArrayFromJSON(float16(), "[1.1, 2.0, 2.5, 3.3]");
auto second = ArrayFromJSON(float16(), "[1.1, 4.0, 3.5, 3.3]");
Expand Down
Loading