Skip to content

Commit

Permalink
[RNTuple] write multiple cols
Browse files Browse the repository at this point in the history
  • Loading branch information
Moelf committed Sep 28, 2024
1 parent 0679f73 commit 4d99b53
Show file tree
Hide file tree
Showing 4 changed files with 67 additions and 43 deletions.
64 changes: 43 additions & 21 deletions src/RNTuple/Writing/TFileWriter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -473,19 +473,45 @@ function rnt_write_observe(io::IO, x::T) where T
WriteObservable(io, pos, len, x)
end

function add_field_column_record!(field_records, column_records, input_T::Type{<:Real}, NAME; parent_field_id)
fr = UnROOT.FieldRecord(zero(UInt32), zero(UInt32), parent_field_id, zero(UInt16), zero(UInt16), 0, -1, -1, string(NAME), RNTUPLE_WRITE_TYPE_CPPNAME_DICT[input_T], "", "")
cr = UnROOT.ColumnRecord(RNTUPLE_WRITE_TYPE_IDX_DICT[input_T]..., parent_field_id, 0x00, 0x00, 0)
push!(field_records, fr)
push!(column_records, cr)
nothing
end

function schema_to_field_column_records(table)
input_schema = schema(table)
input_Ts = input_schema.types
input_names = input_schema.names
field_records = UnROOT.FieldRecord[]
column_records = UnROOT.ColumnRecord[]

for (input_T, input_name) in zip(input_Ts, input_names)
add_field_column_record!(field_records, column_records, input_T, input_name, parent_field_id=length(field_records))
end
return field_records, column_records
end

function generate_page_links(column_records, pages_obses, Nitems)
outer_list = RNTuplePageOuterList{RNTuplePageInnerList{PageDescription}}([])
for (cr, page_obs) in zip(column_records, pages_obses)
inner_list = RNTuplePageInnerList([
PageDescription(Nitems, Locator(div(cr.nbits * Nitems, 8, RoundUp), page_obs.position))
])
push!(outer_list, inner_list)
end
return RNTuplePageTopList([outer_list])
end

function write_rntuple(file::IO, table; file_name="test_ntuple_minimal.root", rntuple_name="myntuple")
if !istable(table)
error("RNTuple writing accepts object compatible with Tables.jl interface, got type $(typeof(table))")
end

input_schema = schema(table)
input_Ncols = length(input_schema.names)
if input_Ncols != 1
error("Currently, RNTuple writing only supports a single, UInt32 column, got $input_Ncols columns")
end
input_T = only(input_schema.types)
input_col = only(columntable(table))
input_length = length(input_col)
input_cols = columntable(table)
input_length = length(input_cols[begin])
if input_length > 65535
error("Input too long: RNTuple writing currently only supports a single page (65535 elements)")
end
Expand All @@ -507,28 +533,24 @@ function write_rntuple(file::IO, table; file_name="test_ntuple_minimal.root", rn

RBlob1_obs = rnt_write_observe(file, Stubs.RBlob1)
rntAnchor_update[:fSeekHeader] = UInt32(position(file))
rnt_header = UnROOT.RNTupleHeader(zero(UInt64), rntuple_name, "", "ROOT v6.33.01", [
UnROOT.FieldRecord(zero(UInt32), zero(UInt32), zero(UInt32), zero(UInt16), zero(UInt16), 0, -1, -1, string(only(input_schema.names)), RNTUPLE_WRITE_TYPE_CPPNAME_DICT[input_T], "", ""),
], [UnROOT.ColumnRecord(RNTUPLE_WRITE_TYPE_IDX_DICT[input_T]..., zero(UInt32), 0x00, 0x00, 0),], UnROOT.AliasRecord[], UnROOT.ExtraTypeInfo[])
field_records, col_records = schema_to_field_column_records(table)
rnt_header = UnROOT.RNTupleHeader(
zero(UInt64), rntuple_name, "", "ROOT v6.33.01",
field_records, col_records,
UnROOT.AliasRecord[], UnROOT.ExtraTypeInfo[]
)

rnt_header_obs = rnt_write_observe(file, rnt_header)
rntAnchor_update[:fNBytesHeader] = rnt_header_obs.len
rntAnchor_update[:fLenHeader] = rnt_header_obs.len

RBlob2_obs = rnt_write_observe(file, Stubs.RBlob2)
page1 = rnt_ary_to_page(input_col)
page1_obs = rnt_write_observe(file, page1)
pages = [rnt_ary_to_page(col) for col in input_cols]
pages_obses = [rnt_write_observe(file, page) for page in pages]

RBlob3_obs = rnt_write_observe(file, Stubs.RBlob3)
cluster_summary = Write_RNTupleListFrame([ClusterSummary(0, input_length)])
nested_page_locations =
UnROOT.RNTuplePageTopList([
UnROOT.RNTuplePageOuterList([
UnROOT.RNTuplePageInnerList([
PageDescription(input_length, UnROOT.Locator(sizeof(input_T) * input_length, page1_obs.position, )),
]),
]),
])
nested_page_locations = generate_page_links(col_records, pages_obses, input_length)

pagelink = UnROOT.PageLink(_checksum(rnt_header_obs.object), cluster_summary.payload, nested_page_locations)
pagelink_obs = rnt_write_observe(file, pagelink)
Expand Down
2 changes: 2 additions & 0 deletions src/RNTuple/footer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,8 @@ for x in (:RNTuplePageTopList, :RNTuplePageOuterList, :RNTuplePageInnerList)
Base.size(r::$x) = size(r.payload)
Base.getindex(r::$x, i) = r.payload[i]
Base.setindex!(r::$x, v, i) = (r.payload[i] = v)
Base.push!(r::$x, v) = push!(r.payload, v)
Base.append!(r::$x, v) = append!(r.payload, v)

end
end
Expand Down
2 changes: 1 addition & 1 deletion src/RNTuple/header.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
struct FieldRecord
Base.@kwdef struct FieldRecord
field_version::UInt32
type_version::UInt32
parent_field_id::UInt32
Expand Down
42 changes: 21 additions & 21 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,30 +4,30 @@ using UnROOT
nthreads = UnROOT._maxthreadid()
nthreads == 1 && @warn "Running on a single thread. Please re-run the test suite with at least two threads (`julia --threads 2 ...`)"

@testset "UnROOT tests" verbose = true begin
include("Aqua.jl")
include("bootstrapping.jl")
include("compressions.jl")
include("jagged.jl")
include("lazy.jl")
include("histograms.jl")
include("views.jl")
include("multithreading.jl")
include("remote.jl")
include("displays.jl")
include("type_stability.jl")
include("utils.jl")
include("misc.jl")
# @testset "UnROOT tests" verbose = true begin
# include("Aqua.jl")
# include("bootstrapping.jl")
# include("compressions.jl")
# include("jagged.jl")
# include("lazy.jl")
# include("histograms.jl")
# include("views.jl")
# include("multithreading.jl")
# include("remote.jl")
# include("displays.jl")
# include("type_stability.jl")
# include("utils.jl")
# include("misc.jl")

include("type_support.jl")
include("custom_bootstrapping.jl")
include("lorentzvectors.jl")
include("NanoAOD.jl")
# include("type_support.jl")
# include("custom_bootstrapping.jl")
# include("lorentzvectors.jl")
# include("NanoAOD.jl")

include("issues.jl")
# include("issues.jl")

if VERSION >= v"1.9"
include("rntuple.jl")
# include("rntuple.jl")
include("./RNTupleWriting/lowlevel.jl")
end
end
# end

0 comments on commit 4d99b53

Please sign in to comment.