diff --git a/Framework/Core/include/Framework/RootArrowFilesystem.h b/Framework/Core/include/Framework/RootArrowFilesystem.h index df00ce4fa8a76..7c8385ccd2b9d 100644 --- a/Framework/Core/include/Framework/RootArrowFilesystem.h +++ b/Framework/Core/include/Framework/RootArrowFilesystem.h @@ -87,6 +87,11 @@ class TTreeFileSystem : public VirtualRootFileSystemBase { return std::dynamic_pointer_cast(shared_from_this()); }; + + arrow::Result> OpenOutputStream( + const std::string& path, + const std::shared_ptr& metadata) override; + virtual TTree* GetTree(arrow::dataset::FileSource source) = 0; }; @@ -128,6 +133,10 @@ class TFileFileSystem : public VirtualRootFileSystemBase std::shared_ptr GetSubFilesystem(arrow::dataset::FileSource source) override; + arrow::Result> OpenOutputStream( + const std::string& path, + const std::shared_ptr& metadata) override; + // We can go back to the TFile in case this is needed. TDirectoryFile* GetFile() { @@ -218,6 +227,29 @@ class TTreeFileFormat : public arrow::dataset::FileFormat const std::shared_ptr& fragment) const override; }; +// An arrow output stream which allows writing to a TTree: it carries the tree to the writer, raw byte writes are not supported. +class TTreeOutputStream : public arrow::io::OutputStream +{ + public: + TTreeOutputStream(TTree* t); // keeps t in mTree + + arrow::Status Close() override; // closes the file the tree currently belongs to + + arrow::Result Tell() const override; // always reports NotImplemented + + arrow::Status Write(const void* data, int64_t nbytes) override; // always reports NotImplemented, rows are written through the tree instead + + bool closed() const override; // true when the file the tree currently belongs to is not open + + TTree* GetTree() // gives the writer access to the wrapped tree + { + return mTree; + } + + private: + TTree* mTree; // tree handed to the constructor +}; + } // namespace o2::framework #endif // O2_FRAMEWORK_ROOT_ARROW_FILESYSTEM_H_ diff --git a/Framework/Core/src/RootArrowFilesystem.cxx b/Framework/Core/src/RootArrowFilesystem.cxx index 46489141c3173..7581ee57e5b9f 100644 --- a/Framework/Core/src/RootArrowFilesystem.cxx +++ b/Framework/Core/src/RootArrowFilesystem.cxx @@ -11,6 +11,7 @@ #include "Framework/RootArrowFilesystem.h" #include "Framework/Endian.h" #include "Framework/RuntimeError.h" +#include "Framework/Signpost.h" #include #include 
#include @@ -24,6 +25,13 @@ #include #include #include +#include +#include +#include +#include + + +O2_DECLARE_DYNAMIC_LOG(root_arrow_fs); namespace { @@ -76,6 +84,7 @@ auto arrowTypeFromROOT(EDataType type, int size) } namespace o2::framework { +using arrow::Status; TFileFileSystem::TFileFileSystem(TDirectoryFile* f, size_t readahead) : VirtualRootFileSystemBase(), @@ -116,6 +125,15 @@ arrow::Result TFileFileSystem::GetFileInfo(const std::strin return result; } +// Create a TTree named after path and hand it out wrapped in a TTreeOutputStream. +// NOTE(review): the tree is allocated with new and not deleted here; presumably the current ROOT directory owns it, confirm. +arrow::Result> TFileFileSystem::OpenOutputStream( + const std::string& path, + const std::shared_ptr& metadata) +{ + auto* t = new TTree(path.c_str(), "should put a name here"); + auto stream = std::make_shared(t); + return stream; +} + arrow::Result VirtualRootFileSystemBase::GetFileInfo(std::string const&) { arrow::fs::FileInfo result; @@ -267,9 +285,279 @@ arrow::Result> TTreeFileFormat::Ma return std::dynamic_pointer_cast(fragment); } +// An arrow outputstream which allows to write to a ttree +TTreeOutputStream::TTreeOutputStream(TTree* t) + : mTree(t) +{ +} + +// Close the file the tree belongs to, if there is one. +arrow::Status TTreeOutputStream::Close() +{ + // GetCurrentFile() is null for a tree which is not attached to a file, e.g. after SetDirectory(nullptr). + if (auto* file = mTree->GetCurrentFile()) { + file->Close(); + } + return arrow::Status::OK(); +} + +// The position inside a TTree is not a byte offset, so it cannot be reported. +arrow::Result TTreeOutputStream::Tell() const +{ + return arrow::Result(arrow::Status::NotImplemented("Cannot move")); +} + +// Raw bytes are rejected: rows reach the tree through the dataset writer. +arrow::Status TTreeOutputStream::Write(const void* data, int64_t nbytes) +{ + return arrow::Status::NotImplemented("Cannot write raw bytes to a TTree"); +} + +// A tree without a current file is reported as closed instead of dereferencing a null file. +bool TTreeOutputStream::closed() const +{ + auto* file = mTree->GetCurrentFile(); + return file == nullptr || file->IsOpen() == false; +} + +// Map an arrow primitive type to the type suffix of a ROOT leaf list. +char const* rootSuffixFromArrow(arrow::Type::type id) +{ + switch (id) { + case arrow::Type::BOOL: + return "/O"; + case arrow::Type::UINT8: + return "/b"; + case arrow::Type::UINT16: + return "/s"; + case arrow::Type::UINT32: + return "/i"; + case arrow::Type::UINT64: + return "/l"; + case arrow::Type::INT8: + return "/B"; + case arrow::Type::INT16: + return "/S"; + case arrow::Type::INT32: + return "/I"; + case arrow::Type::INT64: + return "/L"; + case 
arrow::Type::FLOAT: + return "/F"; + case arrow::Type::DOUBLE: + return "/D"; + default: + throw runtime_error("Unsupported arrow column type"); + } +} + +// FileWriter which stores arrow record batches as entries of a TTree: one branch per schema field, plus a name_size branch for each variable length list. +class TTreeFileWriter : public arrow::dataset::FileWriter +{ + std::vector branches; + std::vector sizesBranches; + std::vector> valueArrays; + std::vector> sizeArrays; + std::vector> valueTypes; + + std::vector valuesIdealBasketSize; + std::vector sizeIdealBasketSize; + + std::vector typeSizes; + std::vector listSizes; + bool firstBasket = true; + + // Derive the basket size of every branch from the number of rows of the first batch. + void finaliseBasketSize(std::shared_ptr firstBatch) + { + O2_SIGNPOST_ID_FROM_POINTER(sid, root_arrow_fs, this); + O2_SIGNPOST_START(root_arrow_fs, sid, "finaliseBasketSize", "First batch with %lli rows received and %zu columns", + firstBatch->num_rows(), firstBatch->columns().size()); + for (size_t i = 0; i < branches.size(); i++) { + auto* branch = branches[i]; + auto* sizeBranch = sizesBranches[i]; + + int valueSize = valueTypes[i]->byte_width(); + if (listSizes[i] == 1) { + O2_SIGNPOST_EVENT_EMIT(root_arrow_fs, sid, "finaliseBasketSize", "Branch %s exists and uses %d bytes per entry for %lli entries.", + branch->GetName(), valueSize, firstBatch->num_rows()); + assert(sizeBranch == nullptr); + branch->SetBasketSize(1024 + firstBatch->num_rows() * valueSize); + } else if (listSizes[i] == -1) { + O2_SIGNPOST_EVENT_EMIT(root_arrow_fs, sid, "finaliseBasketSize", "Branch %s exists and uses %d bytes per entry.", + branch->GetName(), valueSize); + // Look up the list column backing this branch by its name. + auto column = firstBatch->GetColumnByName(branch->GetName()); + auto list = std::static_pointer_cast(column); + O2_SIGNPOST_EVENT_EMIT(root_arrow_fs, sid, "finaliseBasketSize", "Branch %s needed. 
Associated size branch %s and there are %lli entries of size %d in that list.", + branch->GetName(), sizeBranch->GetName(), list->length(), valueSize); + branch->SetBasketSize(1024 + firstBatch->num_rows() * valueSize * list->length()); + sizeBranch->SetBasketSize(1024 + firstBatch->num_rows() * 4); + } else { + O2_SIGNPOST_EVENT_EMIT(root_arrow_fs, sid, "finaliseBasketSize", "Branch %s needed. There are %lli entries per array of size %d in that list.", + branch->GetName(), listSizes[i], valueSize); + assert(sizeBranch == nullptr); + branch->SetBasketSize(1024 + firstBatch->num_rows() * valueSize * listSizes[i]); + } + + auto field = firstBatch->schema()->field(i); + if (field->name().starts_with("fIndexArray")) { + // One int per array to keep track of the size + int idealBasketSize = 4 * firstBatch->num_rows() + 1024 + field->type()->byte_width() * firstBatch->num_rows(); // minimal additional size needed, otherwise we get 2 baskets + int basketSize = std::max(32000, idealBasketSize); // keep a minimum value + // Only variable length lists have a size branch, the other cases assert above that it is null. + if (sizeBranch != nullptr) { + sizeBranch->SetBasketSize(basketSize); + } + branch->SetBasketSize(basketSize); + } + } + O2_SIGNPOST_END(root_arrow_fs, sid, "finaliseBasketSize", "Done"); + } + + public: + // Create the TTree based on the physical_schema, not the one in the batch. + // The write method will have to reconcile the two schemas. + TTreeFileWriter(std::shared_ptr schema, std::shared_ptr options, + std::shared_ptr destination, + arrow::fs::FileLocator destination_locator) + : FileWriter(schema, options, destination, destination_locator) + { + // Batches have the same number of entries for each column. + auto treeStream = std::dynamic_pointer_cast(destination_); + // MakeWriter accepts any output stream, but only a TTreeOutputStream carries a tree to fill. + if (treeStream == nullptr) { + throw runtime_error("TTreeFileWriter needs a TTreeOutputStream as destination"); + } + TTree* tree = treeStream->GetTree(); + + for (auto i = 0u; i < schema->fields().size(); ++i) { + auto& field = schema->field(i); + listSizes.push_back(1); + + int valuesIdealBasketSize = 0; + // Construct all the needed branches. 
+ switch (field->type()->id()) { + case arrow::Type::FIXED_SIZE_LIST: { + listSizes.back() = std::static_pointer_cast(field->type())->list_size(); + valueTypes.push_back(field->type()->field(0)->type()); // must come first: the basket size below reads valueTypes.back() + valuesIdealBasketSize = 1024 + valueTypes.back()->byte_width() * listSizes.back(); + sizesBranches.push_back(nullptr); + std::string leafList = fmt::format("{}[{}]{}", field->name(), listSizes.back(), rootSuffixFromArrow(valueTypes.back()->id())); + branches.push_back(tree->Branch(field->name().c_str(), (char*)nullptr, leafList.c_str())); + } break; + case arrow::Type::LIST: { + valueTypes.push_back(field->type()->field(0)->type()); + listSizes.back() = -1; // VLA, we need to calculate it on the fly; -1 is the marker finaliseBasketSize checks for + std::string leafList = fmt::format("{}[{}_size]{}", field->name(), field->name(), rootSuffixFromArrow(valueTypes.back()->id())); + std::string sizeLeafList = field->name() + "_size/I"; + sizesBranches.push_back(tree->Branch((field->name() + "_size").c_str(), (char*)nullptr, sizeLeafList.c_str())); + branches.push_back(tree->Branch(field->name().c_str(), (char*)nullptr, leafList.c_str())); + // Notice that this could be replaced by a better guess of the + // average size of the list elements, but this is not trivial. + } break; + default: { + valueTypes.push_back(field->type()); + std::string leafList = field->name() + rootSuffixFromArrow(valueTypes.back()->id()); + sizesBranches.push_back(nullptr); + branches.push_back(tree->Branch(field->name().c_str(), (char*)nullptr, leafList.c_str())); + } break; + } + } + // We create the branches from the schema + } + + // Append every row of the batch to the tree; the first batch received also fixes the basket sizes. + arrow::Status Write(const std::shared_ptr& batch) override + { + if (firstBasket) { + firstBasket = false; + finaliseBasketSize(batch); + } + + // Support writing empty tables + if (batch->columns().empty() || batch->num_rows() == 0) { + return arrow::Status::OK(); + } + + // Batches have the same number of entries for each column. 
+ auto treeStream = std::dynamic_pointer_cast(destination_); + TTree* tree = treeStream->GetTree(); + + // Caches for the vectors of bools: one initially null slot per branch, so that caches[bi] is always a valid element. + std::vector> caches(branches.size()); + + // valueArrays is a member: drop the arrays of the previous batch so that index bi refers to this batch. + valueArrays.clear(); + for (auto i = 0u; i < batch->columns().size(); ++i) { + auto column = batch->column(i); + auto& field = batch->schema()->field(i); + + valueArrays.push_back(nullptr); + + switch (field->type()->id()) { + case arrow::Type::FIXED_SIZE_LIST: { + auto list = std::static_pointer_cast(column); + valueArrays.back() = list->values(); + } break; + case arrow::Type::LIST: { + auto list = std::static_pointer_cast(column); + // Keep the child values: the buffer of the list array itself holds offsets, not values. + valueArrays.back() = list->values(); + } break; + default: + valueArrays.back() = column; + } + } + + int64_t pos = 0; + while (pos < batch->num_rows()) { + for (size_t bi = 0; bi < branches.size(); ++bi) { + auto* branch = branches[bi]; + auto* sizeBranch = sizesBranches[bi]; + auto array = batch->column(bi); + auto& field = batch->schema()->field(bi); + auto& listSize = listSizes[bi]; + auto valueType = valueTypes[bi]; + auto valueArray = valueArrays[bi]; + + if (field->type()->id() == arrow::Type::BOOL) { + // Booleans are written as one byte each ("/O"), so the arrow array is expanded once per batch and reused for every row. + if (caches[bi] == nullptr) { + auto boolArray = std::static_pointer_cast(array); + int64_t length = boolArray->length(); + arrow::UInt8Builder builder; + ARROW_RETURN_NOT_OK(builder.Reserve(length)); + + for (int64_t i = 0; i < length; ++i) { + if (boolArray->IsValid(i)) { + // Expand each boolean value (true/false) to uint8 (1/0) + uint8_t value = boolArray->Value(i) ? 
1 : 0; + ARROW_RETURN_NOT_OK(builder.Append(value)); + } else { + // Append null for invalid entries + ARROW_RETURN_NOT_OK(builder.AppendNull()); + } + } + + ARROW_RETURN_NOT_OK(builder.Finish(&caches[bi])); + } + // Point the branch at the byte of the current row, not at the start of the cache. + branch->SetAddress((void*)(caches[bi]->values()->data() + pos)); + continue; + } + // NOTE(review): array->offset() counts elements but is added as a byte offset below; confirm sliced batches never reach this writer. + switch (field->type()->id()) { + case arrow::Type::LIST: { + auto list = std::static_pointer_cast(array); + listSize = list->value_length(pos); + uint8_t const* buffer = std::static_pointer_cast(valueArray)->values()->data() + array->offset() + list->value_offset(pos) * valueType->byte_width(); + branch->SetAddress((void*)buffer); + sizeBranch->SetAddress(&listSize); + }; + break; + case arrow::Type::FIXED_SIZE_LIST: + default: { + uint8_t const* buffer = std::static_pointer_cast(valueArray)->values()->data() + array->offset() + pos * listSize * valueType->byte_width(); + branch->SetAddress((void*)buffer); + }; + } + } + tree->Fill(); + ++pos; + } + return arrow::Status::OK(); + } + + // Write the tree to its directory and detach it from that directory. + arrow::Future<> FinishInternal() override + { + auto treeStream = std::dynamic_pointer_cast(destination_); + TTree* tree = treeStream->GetTree(); + tree->Write("", TObject::kOverwrite); + tree->SetDirectory(nullptr); + + // A default-constructed Future is invalid, so completion is reported explicitly. + return arrow::Future<>::MakeFinished(); + }; +}; + arrow::Result> TTreeFileFormat::MakeWriter(std::shared_ptr destination, std::shared_ptr schema, std::shared_ptr options, arrow::fs::FileLocator destination_locator) const { - throw std::runtime_error("Unsupported operation"); + auto writer = std::make_shared(schema, options, destination, destination_locator); + return std::dynamic_pointer_cast(writer); } std::shared_ptr TTreeFileFormat::DefaultWriteOptions() @@ -401,8 +689,10 @@ arrow::Result TTreeFileFormat::ScanBatchesAsync( int64_t listSize = 1; if (auto fixedSizeList = std::dynamic_pointer_cast(physicalField->type())) { listSize = fixedSizeList->list_size(); + typeSize = fixedSizeList->field(0)->type()->byte_width(); } else if (auto vlaListType = std::dynamic_pointer_cast(physicalField->type())) { listSize = -1; + typeSize = 
vlaListType->field(0)->type()->byte_width(); // fixedSizeList is null in this branch, the value type comes from the matched list type } if (listSize == -1) { mSizeBranch = branch->GetTree()->GetBranch((std::string{branch->GetName()} + "_size").c_str()); @@ -474,6 +764,15 @@ arrow::Result TTreeFileFormat::ScanBatchesAsync( return generator; } + +// Expose the tree which GetTree() resolves for path as an arrow output stream. +arrow::Result> TTreeFileSystem::OpenOutputStream( + const std::string& path, + const std::shared_ptr& metadata) +{ + auto stream = std::make_shared(GetTree({path, shared_from_this()})); + return stream; +} + TBufferFileFS::TBufferFileFS(TBufferFile* f) : VirtualRootFileSystemBase(), mBuffer(f), @@ -512,5 +811,4 @@ std::shared_ptr TBufferFileFS::GetSubFilesystem(arrow } return mFilesystem; } - } // namespace o2::framework diff --git a/Framework/Core/src/TableTreeHelpers.cxx b/Framework/Core/src/TableTreeHelpers.cxx index c20febaac517d..d0fdd0ced5779 100644 --- a/Framework/Core/src/TableTreeHelpers.cxx +++ b/Framework/Core/src/TableTreeHelpers.cxx @@ -512,7 +512,7 @@ void TreeToTable::addAllColumns(TTree* tree, std::vector&& names) if (strncmp(reader->branch()->GetName(), "fIndexArray", strlen("fIndexArray")) == 0) { std::string sizeBranchName = reader->branch()->GetName(); sizeBranchName += "_size"; - TBranch* sizeBranch = (TBranch*)tree->GetBranch(sizeBranchName.c_str()); + auto* sizeBranch = (TBranch*)tree->GetBranch(sizeBranchName.c_str()); if (sizeBranch) { tree->AddBranchToCache(sizeBranch); } diff --git a/Framework/Core/test/test_Root2ArrowTable.cxx b/Framework/Core/test/test_Root2ArrowTable.cxx index 599f1062c63a0..03f0977a4c0c4 100644 --- a/Framework/Core/test/test_Root2ArrowTable.cxx +++ b/Framework/Core/test/test_Root2ArrowTable.cxx @@ -358,4 +358,112 @@ TEST_CASE("RootTree2Dataset") REQUIRE(result.ok()); REQUIRE((*result)->columns().size() == 7); REQUIRE((*result)->num_rows() == 100); + + { + auto int_array = std::static_pointer_cast((*result)->GetColumnByName("ev")); + for (int64_t j = 0; j < int_array->length(); j++) { + REQUIRE(int_array->Value(j) == j + 1); + } + } + + { + auto list_array = 
std::static_pointer_cast((*result)->GetColumnByName("xyz")); + + // Iterate over the FixedSizeListArray + for (int64_t i = 0; i < list_array->length(); i++) { + auto value_slice = list_array->value_slice(i); + auto float_array = std::static_pointer_cast(value_slice); + + REQUIRE(float_array->Value(0) == 1); + REQUIRE(float_array->Value(1) == 2); + REQUIRE(float_array->Value(2) == i + 1); + } + } + + { + auto list_array = std::static_pointer_cast((*result)->GetColumnByName("ij")); + + // Iterate over the FixedSizeListArray + for (int64_t i = 0; i < list_array->length(); i++) { + auto value_slice = list_array->value_slice(i); + auto int_array = std::static_pointer_cast(value_slice); + REQUIRE(int_array->Value(0) == i); + REQUIRE(int_array->Value(1) == i + 1); + } + } + + // Write the table which was just read into a tree of an in-memory file. + auto* output = new TMemFile("foo", "RECREATE"); + auto outFs = std::make_shared(output, 0); + arrow::fs::FileLocator locator{outFs, "/DF_3"}; + + auto destination = outFs->OpenOutputStream(locator.path, {}); + REQUIRE(destination.ok()); + + auto writer = format->MakeWriter(*destination, schema, {}, locator); + // Check the Result before dereferencing it, so that a failure shows up as a test failure. + REQUIRE(writer.ok()); + auto success = writer->get()->Write(*result); + auto rootDestination = std::dynamic_pointer_cast(*destination); + // The writer fills the tree carried by the stream, so the cast must succeed. + REQUIRE(rootDestination != nullptr); + + REQUIRE(success.ok()); + // Let's read it back... 
+ // Everything below goes through source2 and the *Written objects, otherwise the original input would simply be checked a second time. + arrow::dataset::FileSource source2("/DF_3", outFs); + auto newTreeFS = outFs->GetSubFilesystem(source2); + + REQUIRE(format->IsSupported(source2) == true); + + auto schemaOptWritten = format->Inspect(source2); + REQUIRE(schemaOptWritten.ok()); + auto schemaWritten = *schemaOptWritten; + REQUIRE(schemaWritten->num_fields() == 7); + REQUIRE(schemaWritten->field(0)->type()->id() == arrow::float32()->id()); + REQUIRE(schemaWritten->field(1)->type()->id() == arrow::float32()->id()); + REQUIRE(schemaWritten->field(2)->type()->id() == arrow::float32()->id()); + REQUIRE(schemaWritten->field(3)->type()->id() == arrow::float64()->id()); + REQUIRE(schemaWritten->field(4)->type()->id() == arrow::int32()->id()); + REQUIRE(schemaWritten->field(5)->type()->id() == arrow::fixed_size_list(arrow::float32(), 3)->id()); + REQUIRE(schemaWritten->field(6)->type()->id() == arrow::fixed_size_list(arrow::int32(), 2)->id()); + + auto fragmentWritten = format->MakeFragment(source2, {}, schemaWritten); + REQUIRE(fragmentWritten.ok()); + auto optionsWritten = std::make_shared(); + optionsWritten->dataset_schema = schemaWritten; + auto scannerWritten = format->ScanBatchesAsync(optionsWritten, *fragmentWritten); + REQUIRE(scannerWritten.ok()); + auto batchesWritten = (*scannerWritten)(); + auto resultWritten = batchesWritten.result(); + REQUIRE(resultWritten.ok()); + REQUIRE((*resultWritten)->columns().size() == 7); + REQUIRE((*resultWritten)->num_rows() == 100); + + { + auto int_array = std::static_pointer_cast((*resultWritten)->GetColumnByName("ev")); + for (int64_t j = 0; j < int_array->length(); j++) { + REQUIRE(int_array->Value(j) == j + 1); + } + } + + { + auto list_array = std::static_pointer_cast((*resultWritten)->GetColumnByName("xyz")); + + // Iterate over the FixedSizeListArray + for (int64_t i = 0; i < list_array->length(); i++) { + auto value_slice = list_array->value_slice(i); + auto float_array = std::static_pointer_cast(value_slice); + + REQUIRE(float_array->Value(0) == 1); + REQUIRE(float_array->Value(1) == 2); + 
REQUIRE(float_array->Value(2) == i + 1); + } + } + + { + auto list_array = std::static_pointer_cast((*resultWritten)->GetColumnByName("ij")); + + // Iterate over the FixedSizeListArray + for (int64_t i = 0; i < list_array->length(); i++) { + auto value_slice = list_array->value_slice(i); + auto int_array = std::static_pointer_cast(value_slice); + REQUIRE(int_array->Value(0) == i); + REQUIRE(int_array->Value(1) == i + 1); + } + } }