Skip to content

Commit

Permalink
[ntuple] add unit test for in-memory union merging
Browse files Browse the repository at this point in the history
  • Loading branch information
silverweed committed Feb 6, 2025
1 parent 6c034e0 commit 9f4458c
Showing 1 changed file with 108 additions and 0 deletions.
108 changes: 108 additions & 0 deletions tree/ntuple/v7/test/ntuple_merger.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -1718,3 +1718,111 @@ TEST(RNTupleMerger, MergeIncrementalLMExt)
}
}
}

TEST(RNTupleMerger, MergeIncrementalLMExtMemFile)
{
   // Same as MergeIncrementalLMExt but using TMemFiles.
   //
   // Input `k` (0-based) is written with fields f_0 .. f_k and `nFills` entries. The inputs are then merged
   // one at a time in Union mode, and the output is checked to contain every field with the expected values.
   const int nInputs = 12;
   // Number of entries written to each input. The verification below derives all its offsets from this value.
   const int nFills = 5;

   std::vector<std::unique_ptr<TMemFile>> inputFiles;
   // The model is shared across iterations and grows by one field per input; each writer gets a clone of it.
   auto model = RNTupleModel::Create();
   for (int fileIdx = 0; fileIdx < nInputs; ++fileIdx) {
      const auto fileName = std::string("memFile_") + std::to_string(fileIdx);
      auto &file = inputFiles.emplace_back(std::make_unique<TMemFile>(fileName.c_str(), "CREATE"));

      // Each input gets a different model, so we can exercise the late model extension.
      // Just to have some variation, use different types depending on the field index
      const auto fieldName = std::string("f_") + std::to_string(fileIdx);
      switch (fileIdx % 3) {
      case 0: model->MakeField<int>(fieldName); break;
      case 1: model->MakeField<float>(fieldName); break;
      default: model->MakeField<std::string>(fieldName);
      }

      auto writer = RNTupleWriter::Append(model->Clone(), "ntpl", *file);

      // Fill the RNTuple with nFills entries; every field present in this input gets the value
      // (fileIdx + fillIdx + fieldIdx), stored according to the field's type.
      const auto &entry = writer->GetModel().GetDefaultEntry();
      for (int fillIdx = 0; fillIdx < nFills; ++fillIdx) {
         for (int fieldIdx = 0; fieldIdx < fileIdx + 1; ++fieldIdx) {
            const auto fldName = std::string("f_") + std::to_string(fieldIdx);
            switch (fieldIdx % 3) {
            case 0: *entry.GetPtr<int>(fldName) = fileIdx + fillIdx + fieldIdx; break;
            case 1: *entry.GetPtr<float>(fldName) = fileIdx + fillIdx + fieldIdx; break;
            default: *entry.GetPtr<std::string>(fldName) = std::to_string(fileIdx + fillIdx + fieldIdx);
            }
         }
         writer->Fill();
      }
   }

   // Incrementally merge the inputs
   FileRaii fileGuard("test_ntuple_merge_incr_lmext_memfile.root");
   const auto compression = 505;

   {
      TFileMerger merger(kFALSE, kFALSE);
      merger.OutputFile(fileGuard.GetPath().c_str(), "RECREATE", compression);
      merger.SetMergeOptions(TString("rntuple.MergingMode=Union"));

      // Add one input at a time and merge after each addition.
      for (int i = 0; i < nInputs; ++i) {
         merger.AddFile(inputFiles[i].get());
         bool result =
            merger.PartialMerge(TFileMerger::kIncremental | TFileMerger::kNonResetable | TFileMerger::kKeepCompression);
         ASSERT_TRUE(result);
      }
   }

   // Now verify that the output file contains all the expected data.
   {
      auto reader = RNTupleReader::Open("ntpl", fileGuard.GetPath());
      const auto &desc = reader->GetDescriptor();
      for (int i = 0; i < nInputs; ++i) {
         const auto fieldId = desc.FindFieldId(std::string("f_") + std::to_string(i));
         // Fatal assertion: the field descriptor lookup below is meaningless with an invalid id.
         ASSERT_NE(fieldId, ROOT::Experimental::kInvalidDescriptorId);
         const auto &fdesc = desc.GetFieldDescriptor(fieldId);
         for (const auto &colId : fdesc.GetLogicalColumnIds()) {
            const auto &cdesc = desc.GetColumnDescriptor(colId);
            // Field f_i first appears in input i, i.e. after i * nFills merged entries: the column with
            // index 0 is expected to start there, all other columns of the field at 0.
            EXPECT_EQ(cdesc.GetFirstElementIndex(), (cdesc.GetIndex() == 0) * i * nFills);
         }
      }

      // One view per field, grouped by type. Group `i` holds fields f_(3i), f_(3i+1), f_(3i+2);
      // these arrays must list all nInputs fields.
      RNTupleView<int> v_int[] = {
         reader->GetView<int>("f_0"),
         reader->GetView<int>("f_3"),
         reader->GetView<int>("f_6"),
         reader->GetView<int>("f_9"),
      };
      RNTupleView<float> v_float[] = {
         reader->GetView<float>("f_1"),
         reader->GetView<float>("f_4"),
         reader->GetView<float>("f_7"),
         reader->GetView<float>("f_10"),
      };
      RNTupleView<std::string> v_string[] = {
         reader->GetView<std::string>("f_2"),
         reader->GetView<std::string>("f_5"),
         reader->GetView<std::string>("f_8"),
         reader->GetView<std::string>("f_11"),
      };
      for (auto entryId : reader->GetEntryRange()) {
         // Map the global entry back to the input it came from and its position within that input.
         const int fileIdx = entryId / nFills;
         const int localEntryId = entryId % nFills;

         for (int i = 0; i < nInputs / 3; ++i) {
            const int intFieldIdx = i * 3;
            const int floatFieldIdx = intFieldIdx + 1;
            const int stringFieldIdx = intFieldIdx + 2;

            // Input `fileIdx` only contains fields f_0 .. f_fileIdx. For entries coming from an input that
            // lacks the field, the test expects a zero / empty value; otherwise the value written above.
            auto x0 = v_int[i](entryId);
            const int expected_x0 = (fileIdx >= intFieldIdx) ? (fileIdx + localEntryId + intFieldIdx) : 0;
            EXPECT_EQ(x0, expected_x0);

            auto x1 = v_float[i](entryId);
            const float expected_x1 = (fileIdx >= floatFieldIdx) ? (fileIdx + localEntryId + floatFieldIdx) : 0;
            EXPECT_FLOAT_EQ(x1, expected_x1);

            auto x2 = v_string[i](entryId);
            const std::string expected_x2 =
               (fileIdx >= stringFieldIdx) ? std::to_string(fileIdx + localEntryId + stringFieldIdx) : "";
            EXPECT_EQ(x2, expected_x2);
         }
      }
   }
}

0 comments on commit 9f4458c

Please sign in to comment.