Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Update athena dump reader #4011

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "ActsExamples/Framework/DataHandle.hpp"
#include "ActsExamples/Framework/IReader.hpp"
#include "ActsExamples/Framework/ProcessCode.hpp"
#include "ActsExamples/Io/Root/RootAthenaDumpGeoIdCollector.hpp"
#include <ActsExamples/EventData/Cluster.hpp>
#include <ActsExamples/EventData/SimParticle.hpp>
#include <ActsExamples/EventData/Track.hpp>
Expand Down Expand Up @@ -45,7 +46,7 @@ class RootAthenaDumpReader : public IReader {
// Name of tree
std::string treename;
// Name of inputfile
std::string inputfile;
std::vector<std::string> inputfiles;
// name of the output measurements
std::string outputMeasurements = "athena_measurements";
// name of the output pixel space points
Expand All @@ -58,14 +59,19 @@ class RootAthenaDumpReader : public IReader {
std::string outputClusters = "athena_clusters";
// name of the output particles
std::string outputParticles = "athena_particles";
// name of the simhit map
// name of the measurements -> particles map
std::string outputMeasurementParticlesMap = "athena_meas_parts_map";
// name of the particles -> measurements map
std::string outputParticleMeasurementsMap = "athena_parts_meas_map";
// name of the track parameters (fitted by athena?)
std::string outputTrackParameters = "athena_track_parameters";

/// Only extract spacepoints
bool onlySpacepoints = false;

/// Skip truth data
bool noTruth = false;

/// Only extract particles that passed the tracking requirements, for
/// details see:
/// https://gitlab.cern.ch/atlas/athena/-/blob/main/InnerDetector/InDetGNNTracking/src/DumpObjects.cxx?ref_type=heads#L1363
Expand Down Expand Up @@ -165,6 +171,8 @@ class RootAthenaDumpReader : public IReader {
this, "output_measurements"};
WriteDataHandle<IndexMultimap<ActsFatras::Barcode>> m_outputMeasParticleMap{
this, "output_meas_part_map"};
WriteDataHandle<InverseMultimap<ActsFatras::Barcode>> m_outputParticleMeasMap{
this, "output_part_meas_map"};

std::unique_ptr<const Acts::Logger> m_logger;
std::mutex m_read_mutex;
Expand Down
170 changes: 99 additions & 71 deletions Examples/Io/Root/src/RootAthenaDumpReader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,8 @@ RootAthenaDumpReader::RootAthenaDumpReader(
: IReader(),
m_cfg(config),
m_logger(Acts::getDefaultLogger(name(), level)) {
if (m_cfg.inputfile.empty()) {
throw std::invalid_argument("Missing input filename");
if (m_cfg.inputfiles.empty()) {
throw std::invalid_argument("Empty input file list");
}
if (m_cfg.treename.empty()) {
throw std::invalid_argument("Missing tree name");
Expand All @@ -77,10 +77,13 @@ RootAthenaDumpReader::RootAthenaDumpReader(
m_outputStripSpacePoints.initialize(m_cfg.outputStripSpacePoints);
m_outputSpacePoints.initialize(m_cfg.outputSpacePoints);
if (!m_cfg.onlySpacepoints) {
m_outputClusters.initialize(m_cfg.outputClusters);
m_outputParticles.initialize(m_cfg.outputParticles);
m_outputMeasParticleMap.initialize(m_cfg.outputMeasurementParticlesMap);
m_outputMeasurements.initialize(m_cfg.outputMeasurements);
if (!m_cfg.noTruth) {
m_outputClusters.initialize(m_cfg.outputClusters);
m_outputParticles.initialize(m_cfg.outputParticles);
m_outputMeasParticleMap.initialize(m_cfg.outputMeasurementParticlesMap);
m_outputParticleMeasMap.initialize(m_cfg.outputParticleMeasurementsMap);
}
}

if (m_inputchain->GetBranch("SPtopStripDirection") == nullptr) {
Expand Down Expand Up @@ -134,12 +137,6 @@ RootAthenaDumpReader::RootAthenaDumpReader(
m_inputchain->SetBranchAddress("CLphi_module", CLphi_module);
m_inputchain->SetBranchAddress("CLside", CLside);
m_inputchain->SetBranchAddress("CLmoduleID", CLmoduleID);
m_inputchain->SetBranchAddress("CLparticleLink_eventIndex",
&CLparticleLink_eventIndex);
m_inputchain->SetBranchAddress("CLparticleLink_barcode",
&CLparticleLink_barcode);
m_inputchain->SetBranchAddress("CLbarcodesLinked", &CLbarcodesLinked);
m_inputchain->SetBranchAddress("CLparticle_charge", &CLparticle_charge);
m_inputchain->SetBranchAddress("CLphis", &CLphis);
m_inputchain->SetBranchAddress("CLetas", &CLetas);
m_inputchain->SetBranchAddress("CLtots", &CLtots);
Expand All @@ -161,28 +158,40 @@ RootAthenaDumpReader::RootAthenaDumpReader(
m_inputchain->SetBranchAddress("CLnorm_y", CLnorm_y);
m_inputchain->SetBranchAddress("CLnorm_z", CLnorm_z);
m_inputchain->SetBranchAddress("CLlocal_cov", &CLlocal_cov);
m_inputchain->SetBranchAddress("nPartEVT", &nPartEVT);
m_inputchain->SetBranchAddress("Part_event_number", Part_event_number);
m_inputchain->SetBranchAddress("Part_barcode", Part_barcode);
m_inputchain->SetBranchAddress("Part_px", Part_px);
m_inputchain->SetBranchAddress("Part_py", Part_py);
m_inputchain->SetBranchAddress("Part_pz", Part_pz);
m_inputchain->SetBranchAddress("Part_pt", Part_pt);
m_inputchain->SetBranchAddress("Part_eta", Part_eta);
m_inputchain->SetBranchAddress("Part_vx", Part_vx);
m_inputchain->SetBranchAddress("Part_vy", Part_vy);
m_inputchain->SetBranchAddress("Part_vz", Part_vz);
m_inputchain->SetBranchAddress("Part_radius", Part_radius);
m_inputchain->SetBranchAddress("Part_status", Part_status);
m_inputchain->SetBranchAddress("Part_charge", Part_charge);
m_inputchain->SetBranchAddress("Part_pdg_id", Part_pdg_id);
m_inputchain->SetBranchAddress("Part_passed", Part_passed);
m_inputchain->SetBranchAddress("Part_vProdNin", Part_vProdNin);
m_inputchain->SetBranchAddress("Part_vProdNout", Part_vProdNout);
m_inputchain->SetBranchAddress("Part_vProdStatus", Part_vProdStatus);
m_inputchain->SetBranchAddress("Part_vProdBarcode", Part_vProdBarcode);
m_inputchain->SetBranchAddress("Part_vParentID", &Part_vParentID);
m_inputchain->SetBranchAddress("Part_vParentBarcode", &Part_vParentBarcode);
if (!m_cfg.noTruth) {
m_inputchain->SetBranchAddress("CLparticleLink_eventIndex",
&CLparticleLink_eventIndex);
m_inputchain->SetBranchAddress("CLparticleLink_barcode",
&CLparticleLink_barcode);
m_inputchain->SetBranchAddress("CLbarcodesLinked", &CLbarcodesLinked);
m_inputchain->SetBranchAddress("CLparticle_charge", &CLparticle_charge);
}

if (!m_cfg.noTruth) {
m_inputchain->SetBranchAddress("nPartEVT", &nPartEVT);
m_inputchain->SetBranchAddress("Part_event_number", Part_event_number);
m_inputchain->SetBranchAddress("Part_barcode", Part_barcode);
m_inputchain->SetBranchAddress("Part_px", Part_px);
m_inputchain->SetBranchAddress("Part_py", Part_py);
m_inputchain->SetBranchAddress("Part_pz", Part_pz);
m_inputchain->SetBranchAddress("Part_pt", Part_pt);
m_inputchain->SetBranchAddress("Part_eta", Part_eta);
m_inputchain->SetBranchAddress("Part_vx", Part_vx);
m_inputchain->SetBranchAddress("Part_vy", Part_vy);
m_inputchain->SetBranchAddress("Part_vz", Part_vz);
m_inputchain->SetBranchAddress("Part_radius", Part_radius);
m_inputchain->SetBranchAddress("Part_status", Part_status);
m_inputchain->SetBranchAddress("Part_charge", Part_charge);
m_inputchain->SetBranchAddress("Part_pdg_id", Part_pdg_id);
m_inputchain->SetBranchAddress("Part_passed", Part_passed);
m_inputchain->SetBranchAddress("Part_vProdNin", Part_vProdNin);
m_inputchain->SetBranchAddress("Part_vProdNout", Part_vProdNout);
m_inputchain->SetBranchAddress("Part_vProdStatus", Part_vProdStatus);
m_inputchain->SetBranchAddress("Part_vProdBarcode", Part_vProdBarcode);
m_inputchain->SetBranchAddress("Part_vParentID", &Part_vParentID);
m_inputchain->SetBranchAddress("Part_vParentBarcode", &Part_vParentBarcode);
}

m_inputchain->SetBranchAddress("nSP", &nSP);
m_inputchain->SetBranchAddress("SPindex", SPindex);
m_inputchain->SetBranchAddress("SPx", SPx);
Expand All @@ -191,19 +200,25 @@ RootAthenaDumpReader::RootAthenaDumpReader(
m_inputchain->SetBranchAddress("SPCL1_index", SPCL1_index);
m_inputchain->SetBranchAddress("SPCL2_index", SPCL2_index);
m_inputchain->SetBranchAddress("SPisOverlap", SPisOverlap);
m_inputchain->SetBranchAddress("SPradius", SPradius);
m_inputchain->SetBranchAddress("SPcovr", SPcovr);
m_inputchain->SetBranchAddress("SPcovz", SPcovz);
m_inputchain->SetBranchAddress("SPhl_topstrip", SPhl_topstrip);
m_inputchain->SetBranchAddress("SPhl_botstrip", SPhl_botstrip);
m_inputchain->SetBranchAddress("SPtopStripDirection", &SPtopStripDirection);
m_inputchain->SetBranchAddress("SPbottomStripDirection",
&SPbottomStripDirection);
m_inputchain->SetBranchAddress("SPstripCenterDistance",
&SPstripCenterDistance);
m_inputchain->SetBranchAddress("SPtopStripCenterPosition",
&SPtopStripCenterPosition);

if (m_haveStripFeatures) {
m_inputchain->SetBranchAddress("SPradius", SPradius);
m_inputchain->SetBranchAddress("SPcovr", SPcovr);
m_inputchain->SetBranchAddress("SPcovz", SPcovz);
m_inputchain->SetBranchAddress("SPhl_topstrip", SPhl_topstrip);
m_inputchain->SetBranchAddress("SPhl_botstrip", SPhl_botstrip);
m_inputchain->SetBranchAddress("SPtopStripDirection", &SPtopStripDirection);
m_inputchain->SetBranchAddress("SPbottomStripDirection",
&SPbottomStripDirection);
m_inputchain->SetBranchAddress("SPstripCenterDistance",
&SPstripCenterDistance);
m_inputchain->SetBranchAddress("SPtopStripCenterPosition",
&SPtopStripCenterPosition);
}

// These quantities are not used currently and thus commented out
// I would like to keep the code, since it is always a pain to write it
/*
m_inputchain->SetBranchAddress("nTRK", &nTRK);
m_inputchain->SetBranchAddress("TRKindex", TRKindex);
m_inputchain->SetBranchAddress("TRKtrack_fitter", TRKtrack_fitter);
Expand Down Expand Up @@ -239,9 +254,12 @@ RootAthenaDumpReader::RootAthenaDumpReader(
&DTTstTrack_subDetType);
m_inputchain->SetBranchAddress("DTTstCommon_subDetType",
&DTTstCommon_subDetType);
*/

m_inputchain->Add(m_cfg.inputfile.c_str());
ACTS_DEBUG("Adding file " << m_cfg.inputfile << " to tree" << m_cfg.treename);
for (const auto& file : m_cfg.inputfiles) {
m_inputchain->Add(file.c_str());
ACTS_DEBUG("Adding file '" << file << "' to tree " << m_cfg.treename);
}

m_events = m_inputchain->GetEntries();

Expand Down Expand Up @@ -447,25 +465,26 @@ RootAthenaDumpReader::readMeasurements(
}

std::size_t measIndex = measurements.size();
imIdxMap.emplace(im, measIndex);
createMeasurement(measurements, geoId, digiPars);

// Create measurement particles map and particles container
for (const auto& [subevt, barcode] :
Acts::zip(CLparticleLink_eventIndex->at(im),
CLparticleLink_barcode->at(im))) {
auto dummyBarcode = concatInts(barcode, subevt);
// If we don't find the particle, create one with default values
if (particles.find(dummyBarcode) == particles.end()) {
ACTS_VERBOSE("Particle with subevt " << subevt << ", barcode "
<< barcode
<< "not found, create dummy one");
particles.emplace(dummyBarcode, Acts::PdgParticle::eInvalid);
if (!m_cfg.noTruth) {
// Create measurement particles map and particles container
for (const auto& [subevt, barcode] :
Acts::zip(CLparticleLink_eventIndex->at(im),
CLparticleLink_barcode->at(im))) {
auto dummyBarcode = concatInts(barcode, subevt);
// If we don't find the particle, create one with default values
if (particles.find(dummyBarcode) == particles.end()) {
ACTS_VERBOSE("Particle with subevt "
<< subevt << ", barcode " << barcode
<< "not found, create dummy one");
particles.emplace(dummyBarcode, Acts::PdgParticle::eInvalid);
}
measPartMap.insert(
std::pair<Index, ActsFatras::Barcode>{measIndex, dummyBarcode});
}
measPartMap.insert(
std::pair<Index, ActsFatras::Barcode>{measIndex, dummyBarcode});
}

imIdxMap.emplace(im, measIndex);
}

if (measurements.size() < static_cast<std::size_t>(nCL)) {
Expand All @@ -479,8 +498,8 @@ RootAthenaDumpReader::readMeasurements(
}

if (nTotalTotZero > 0) {
ACTS_WARNING(nTotalTotZero << " / " << nCL
<< " clusters have zero time-over-threshold");
ACTS_DEBUG(nTotalTotZero << " / " << nCL
<< " clusters have zero time-over-threshold");
}

return {clusters, measurements, measPartMap, imIdxMap};
Expand Down Expand Up @@ -618,7 +637,8 @@ RootAthenaDumpReader::readSpacepoints(
ACTS_DEBUG("Skipped " << skippedSpacePoints
<< " because of eta/phi overlaps");
}
if (spacePoints.size() < static_cast<std::size_t>(nSP)) {
if (spacePoints.size() <
(static_cast<std::size_t>(nSP) - skippedSpacePoints)) {
ACTS_WARNING("Could not convert " << nSP - spacePoints.size() << " of "
<< nSP << " spacepoints");
}
Expand Down Expand Up @@ -701,19 +721,27 @@ ProcessCode RootAthenaDumpReader::read(const AlgorithmContext& ctx) {
std::optional<std::unordered_map<int, std::size_t>> optImIdxMap;

if (!m_cfg.onlySpacepoints) {
auto candidateParticles = readParticles();
SimParticleContainer candidateParticles;

if (!m_cfg.noTruth) {
candidateParticles = readParticles();
}

auto [clusters, measurements, candidateMeasPartMap, imIdxMap] =
readMeasurements(candidateParticles, ctx.geoContext);
optImIdxMap.emplace(std::move(imIdxMap));

auto [particles, measPartMap] =
reprocessParticles(candidateParticles, candidateMeasPartMap);

m_outputClusters(ctx, std::move(clusters));
m_outputParticles(ctx, std::move(particles));
m_outputMeasParticleMap(ctx, std::move(measPartMap));
m_outputMeasurements(ctx, std::move(measurements));

if (!m_cfg.noTruth) {
auto [particles, measPartMap] =
reprocessParticles(candidateParticles, candidateMeasPartMap);

m_outputParticles(ctx, std::move(particles));
m_outputParticleMeasMap(ctx, invertIndexMultimap(measPartMap));
m_outputMeasParticleMap(ctx, std::move(measPartMap));
}
}

auto [spacePoints, pixelSpacePoints, stripSpacePoints] =
Expand Down
9 changes: 5 additions & 4 deletions Examples/Python/src/Input.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -92,11 +92,12 @@ void addInput(Context& ctx) {

ACTS_PYTHON_DECLARE_READER(
ActsExamples::RootAthenaDumpReader, mex, "RootAthenaDumpReader", treename,
inputfile, outputMeasurements, outputPixelSpacePoints,
inputfiles, outputMeasurements, outputPixelSpacePoints,
outputStripSpacePoints, outputSpacePoints, outputClusters,
outputMeasurementParticlesMap, outputParticles, onlyPassedParticles,
skipOverlapSPsPhi, skipOverlapSPsEta, geometryIdMap, trackingGeometry,
absBoundaryTolerance);
outputMeasurementParticlesMap, outputParticleMeasurementsMap,
outputParticles, onlyPassedParticles, skipOverlapSPsPhi,
skipOverlapSPsEta, geometryIdMap, trackingGeometry, absBoundaryTolerance,
onlySpacepoints, noTruth);

#ifdef WITH_GEOMODEL_PLUGIN
ACTS_PYTHON_DECLARE_READER(ActsExamples::RootAthenaDumpGeoIdCollector, mex,
Expand Down
Loading