Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Common,DPG] PID: Add AO2D metadata handling for pass name in TPC CCDB calls (pull/6793), adding NN-version number , using nSigmaTOFdautrack in protons from Lambda decays #8835

Merged
merged 8 commits into from
Dec 10, 2024
74 changes: 56 additions & 18 deletions Common/TableProducer/PID/pidTPC.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -18,24 +18,32 @@
/// \brief Task to produce PID tables for TPC split for each particle.
/// Only the tables for the mass hypotheses requested are filled, and only for the requested table size ("Full" or "Tiny"). The others are sent empty.
///

#include <utility>
#include <map>
#include <memory>
#include <string>
#include <vector>
// ROOT includes
#include "TFile.h"
#include "TRandom.h"
#include "TSystem.h"

// O2 includes
#include <CCDB/BasicCCDBManager.h>
#include "CCDB/BasicCCDBManager.h"
#include "Framework/AnalysisTask.h"
#include "Framework/runDataProcessing.h"
#include "Framework/ASoAHelpers.h"
#include "ReconstructionDataFormats/Track.h"
#include "CCDB/CcdbApi.h"
#include "Common/DataModel/PIDResponse.h"
#include "Common/Core/PID/TPCPIDResponse.h"
#include "Framework/AnalysisDataModel.h"
#include "Common/DataModel/Multiplicity.h"
#include "Common/DataModel/EventSelection.h"
#include "TableHelper.h"
#include "Tools/ML/model.h"
#include "pidTPCBase.h"
#include "MetadataHelper.h"

using namespace o2;
using namespace o2::framework;
Expand All @@ -45,21 +53,21 @@
using namespace o2::track;
using namespace o2::ml;

MetadataHelper metadataInfo; // Metadata helper

void customize(std::vector<o2::framework::ConfigParamSpec>& workflowOptions)
{
std::vector<ConfigParamSpec> options{{"add-qa", VariantType::Int, 0, {"Legacy. No effect."}}};
std::swap(workflowOptions, options);
}

#include "Framework/runDataProcessing.h"

/// Task to produce the response table
struct tpcPid {
using Trks = soa::Join<aod::Tracks, aod::TracksExtra>;
using Coll = soa::Join<aod::Collisions, aod::PIDMults>;
using Coll = soa::Join<aod::Collisions, aod::PIDMults, aod::EvSels>;

using TrksMC = soa::Join<aod::Tracks, aod::TracksExtra, aod::McTrackLabels>;
using CollMC = soa::Join<aod::Collisions, aod::PIDMults, aod::McCollisionLabels>;
using CollMC = soa::Join<aod::Collisions, aod::PIDMults, aod::McCollisionLabels, aod::EvSels>;

// Tables to produce
Produces<o2::aod::pidTPCFullEl> tablePIDFullEl;
Expand Down Expand Up @@ -90,8 +98,10 @@
OnnxModel network;
o2::ccdb::CcdbApi ccdbApi;
std::map<std::string, std::string> metadata;
std::map<std::string, std::string> nullmetadata;
std::map<std::string, std::string> headers;
std::vector<int> speciesNetworkFlags = std::vector<int>(9);
std::string networkVersion;

// Input parameters
Service<o2::ccdb::BasicCCDBManager> ccdb;
Expand Down Expand Up @@ -187,11 +197,14 @@
speciesNetworkFlags[7] = useNetworkHe;
speciesNetworkFlags[8] = useNetworkAl;

// Initialise metadata object for CCDB calls
// Initialise metadata object for CCDB calls from AO2D metadata
if (recoPass.value == "") {
LOGP(info, "Reco pass not specified; CCDB will take latest available object");
if (metadataInfo.isFullyDefined()) {
metadata["RecoPassName"] = metadataInfo.get("RecoPassName");
LOGP(info, "Automatically setting reco pass for TPC Response to {} from AO2D", metadata["RecoPassName"]);
}
} else {
LOGP(info, "CCDB object will be requested for reconstruction pass {}", recoPass.value);
LOGP(info, "Setting reco pass for TPC response to user-defined name {}", recoPass.value);
metadata["RecoPassName"] = recoPass.value;
}

Expand All @@ -215,17 +228,23 @@
ccdb->setCaching(true);
ccdb->setLocalObjectValidityChecking();
ccdb->setCreatedNotAfter(std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::system_clock::now().time_since_epoch()).count());
ccdbApi.init(url);
if (time != 0) {
LOGP(info, "Initialising TPC PID response for fixed timestamp {} and reco pass {}:", time, recoPass.value);
ccdb->setTimestamp(time);
response = ccdb->getSpecific<o2::pid::tpc::Response>(path, time, metadata);
headers = ccdbApi.retrieveHeaders(path, metadata, time);
if (!response) {
LOGF(warning, "Unable to find TPC parametrisation for specified pass name - falling back to latest object");
response = ccdb->getForTimeStamp<o2::pid::tpc::Response>(path, time);
headers = ccdbApi.retrieveHeaders(path, metadata, time);
networkVersion = headers["NN-Version"];
if (!response) {
LOGF(fatal, "Unable to find any TPC object corresponding to timestamp {}!", time);
}
}
LOG(info) << "Successfully retrieved TPC PID object from CCDB for timestamp " << time << ", period " << headers["LPMProductionTag"] << ", recoPass " << headers["RecoPassName"];
metadata["RecoPassName"] = headers["RecoPassName"]; // Force pass number for NN request to match retrieved BB
response->PrintAll();
}
}
Expand All @@ -236,19 +255,21 @@
return;
} else {
/// CCDB and auto-fetching
ccdbApi.init(url);

if (!autofetchNetworks) {
if (ccdbTimestamp > 0) {
/// Fetching network for specific timestamp
LOG(info) << "Fetching network for timestamp: " << ccdbTimestamp.value;
bool retrieveSuccess = ccdbApi.retrieveBlob(networkPathCCDB.value, ".", metadata, ccdbTimestamp.value, false, networkPathLocally.value);
headers = ccdbApi.retrieveHeaders(networkPathCCDB.value, metadata, ccdbTimestamp.value);
networkVersion = headers["NN-Version"];
if (retrieveSuccess) {
network.initModel(networkPathLocally.value, enableNetworkOptimizations.value, networkSetNumThreads.value, strtoul(headers["Valid-From"].c_str(), NULL, 0), strtoul(headers["Valid-Until"].c_str(), NULL, 0));
std::vector<float> dummyInput(network.getNumInputNodes(), 1.);
network.evalModel(dummyInput); /// Init the model evaluations
LOGP(info, "Retrieved NN corrections for production tag {}, pass number {}, and NN-Version {}", headers["LPMProductionTag"], headers["RecoPassName"], headers["NN-Version"]);
} else {
LOG(fatal) << "Error encountered while fetching/loading the network from CCDB! Maybe the network doesn't exist yet for this runnumber/timestamp?";
LOG(fatal) << "No valid NN object found matching retrieved Bethe-Bloch parametrisation for pass " << metadata["RecoPassName"] << ". Please ensure that the requested pass has dedicated NN corrections available";
}
} else {
/// Taking the network from local file
Expand All @@ -266,16 +287,16 @@
}
}

Partition<Trks> notTPCStandaloneTracks = (aod::track::tpcNClsFindable > (uint8_t)0) && ((aod::track::itsClusterSizes > (uint32_t)0) || (aod::track::trdPattern > (uint8_t)0) || (aod::track::tofExpMom > 0.f && aod::track::tofChi2 > 0.f)); // To count number of tracks for use in NN array
Partition<Trks> notTPCStandaloneTracks = (aod::track::tpcNClsFindable > static_cast<uint8_t>(0)) && ((aod::track::itsClusterSizes > static_cast<uint32_t>(0)) || (aod::track::trdPattern > static_cast<uint8_t>(0)) || (aod::track::tofExpMom > 0.f && aod::track::tofChi2 > 0.f)); // To count number of tracks for use in NN array
Partition<Trks> tracksWithTPC = (aod::track::tpcNClsFindable > (uint8_t)0);

template <typename C, typename T, typename B>
std::vector<float> createNetworkPrediction(C const& collisions, T const& tracks, B const& bcs, const size_t size)
{

std::vector<float> network_prediction;

Check warning on line 297 in Common/TableProducer/PID/pidTPC.cxx

View workflow job for this annotation

GitHub Actions / O2 linter

[name/function-variable]

Use lowerCamelCase for names of functions and variables.

auto start_network_total = std::chrono::high_resolution_clock::now();

Check warning on line 299 in Common/TableProducer/PID/pidTPC.cxx

View workflow job for this annotation

GitHub Actions / O2 linter

[name/function-variable]

Use lowerCamelCase for names of functions and variables.
if (autofetchNetworks) {
const auto& bc = bcs.begin();
// Initialise correct TPC response object before NN setup (for NCl normalisation)
Expand All @@ -286,43 +307,50 @@
LOGP(info, "Retrieving TPC Response for timestamp {} and recoPass {}:", bc.timestamp(), recoPass.value);
}
response = ccdb->getSpecific<o2::pid::tpc::Response>(ccdbPath.value, bc.timestamp(), metadata);
headers = ccdbApi.retrieveHeaders(ccdbPath.value, metadata, bc.timestamp());
networkVersion = headers["NN-Version"];
if (!response) {
LOGP(warning, "!! Could not find a valid TPC response object for specific pass name {}! Falling back to latest uploaded object.", recoPass.value);
LOGP(warning, "!! Could not find a valid TPC response object for specific pass name {}! Falling back to latest uploaded object.", metadata["RecoPassName"]);
headers = ccdbApi.retrieveHeaders(ccdbPath.value, nullmetadata, bc.timestamp());
response = ccdb->getForTimeStamp<o2::pid::tpc::Response>(ccdbPath.value, bc.timestamp());
if (!response) {
LOGP(fatal, "Could not find ANY TPC response object for the timestamp {}!", bc.timestamp());
}
}
LOG(info) << "Successfully retrieved TPC PID object from CCDB for timestamp " << bc.timestamp() << ", period " << headers["LPMProductionTag"] << ", recoPass " << headers["RecoPassName"];
metadata["RecoPassName"] = headers["RecoPassName"]; // Force pass number for NN request to match retrieved BB
response->PrintAll();
}

if (bc.timestamp() < network.getValidityFrom() || bc.timestamp() > network.getValidityUntil()) { // fetches network only if the runnumbers change
LOG(info) << "Fetching network for timestamp: " << bc.timestamp();
bool retrieveSuccess = ccdbApi.retrieveBlob(networkPathCCDB.value, ".", metadata, bc.timestamp(), false, networkPathLocally.value);
headers = ccdbApi.retrieveHeaders(networkPathCCDB.value, metadata, bc.timestamp());
networkVersion = headers["NN-Version"];
if (retrieveSuccess) {
network.initModel(networkPathLocally.value, enableNetworkOptimizations.value, networkSetNumThreads.value, strtoul(headers["Valid-From"].c_str(), NULL, 0), strtoul(headers["Valid-Until"].c_str(), NULL, 0));
std::vector<float> dummyInput(network.getNumInputNodes(), 1.);
network.evalModel(dummyInput);
LOGP(info, "Retrieved NN corrections for production tag {}, pass number {}, NN-Version number{}", headers["LPMProductionTag"], headers["RecoPassName"], headers["NN-Version"]);
} else {
LOG(fatal) << "Error encountered while fetching/loading the network from CCDB! Maybe the network doesn't exist yet for this runnumber/timestamp?";
LOG(fatal) << "No valid NN object found matching retrieved Bethe-Bloch parametrisation for pass " << metadata["RecoPassName"] << ". Please ensure that the requested pass has dedicated NN corrections available";
}
}
}

// Defining some network parameters
int input_dimensions = network.getNumInputNodes();

Check warning on line 342 in Common/TableProducer/PID/pidTPC.cxx

View workflow job for this annotation

GitHub Actions / O2 linter

[name/function-variable]

Use lowerCamelCase for names of functions and variables.
int output_dimensions = network.getNumOutputNodes();

Check warning on line 343 in Common/TableProducer/PID/pidTPC.cxx

View workflow job for this annotation

GitHub Actions / O2 linter

[name/function-variable]

Use lowerCamelCase for names of functions and variables.
const uint64_t track_prop_size = input_dimensions * size;

Check warning on line 344 in Common/TableProducer/PID/pidTPC.cxx

View workflow job for this annotation

GitHub Actions / O2 linter

[name/function-variable]

Use lowerCamelCase for names of functions and variables.
const uint64_t prediction_size = output_dimensions * size;

Check warning on line 345 in Common/TableProducer/PID/pidTPC.cxx

View workflow job for this annotation

GitHub Actions / O2 linter

[name/function-variable]

Use lowerCamelCase for names of functions and variables.

network_prediction = std::vector<float>(prediction_size * 9); // For each mass hypotheses
const float nNclNormalization = response->GetNClNormalization();
float duration_network = 0;

Check warning on line 349 in Common/TableProducer/PID/pidTPC.cxx

View workflow job for this annotation

GitHub Actions / O2 linter

[name/function-variable]

Use lowerCamelCase for names of functions and variables.

std::vector<float> track_properties(track_prop_size);

Check warning on line 351 in Common/TableProducer/PID/pidTPC.cxx

View workflow job for this annotation

GitHub Actions / O2 linter

[name/function-variable]

Use lowerCamelCase for names of functions and variables.
uint64_t counter_track_props = 0;

Check warning on line 352 in Common/TableProducer/PID/pidTPC.cxx

View workflow job for this annotation

GitHub Actions / O2 linter

[name/function-variable]

Use lowerCamelCase for names of functions and variables.
int loop_counter = 0;

Check warning on line 353 in Common/TableProducer/PID/pidTPC.cxx

View workflow job for this annotation

GitHub Actions / O2 linter

[name/function-variable]

Use lowerCamelCase for names of functions and variables.

// Filling a std::vector<float> to be evaluated by the network
// Evaluation on single tracks brings huge overhead: Thus evaluation is done on one large vector
Expand All @@ -342,6 +370,9 @@
track_properties[counter_track_props + 3] = o2::track::pid_constants::sMasses[i];
track_properties[counter_track_props + 4] = trk.has_collision() ? collisions.iteratorAt(trk.collisionId()).multTPC() / 11000. : 1.;
track_properties[counter_track_props + 5] = std::sqrt(nNclNormalization / trk.tpcNClsFound());
if (input_dimensions == 7 && networkVersion == "2") {
track_properties[counter_track_props + 6] = trk.has_collision() ? collisions.iteratorAt(trk.collisionId()).ft0cOccupancyInTimeRange() / 60000. : 1.;
}
counter_track_props += input_dimensions;
}

Expand Down Expand Up @@ -483,13 +514,16 @@
LOGP(info, "Retrieving TPC Response for timestamp {} and recoPass {}:", bc.timestamp(), recoPass.value);
}
response = ccdb->getSpecific<o2::pid::tpc::Response>(ccdbPath.value, bc.timestamp(), metadata);
headers = ccdbApi.retrieveHeaders(ccdbPath.value, metadata, bc.timestamp());
if (!response) {
LOGP(warning, "!! Could not find a valid TPC response object for specific pass name {}! Falling back to latest uploaded object.", recoPass.value);
LOGP(warning, "!! Could not find a valid TPC response object for specific pass name {}! Falling back to latest uploaded object.", metadata["RecoPassName"]);
response = ccdb->getForTimeStamp<o2::pid::tpc::Response>(ccdbPath.value, bc.timestamp());
headers = ccdbApi.retrieveHeaders(ccdbPath.value, nullmetadata, bc.timestamp());
if (!response) {
LOGP(fatal, "Could not find ANY TPC response object for the timestamp {}!", bc.timestamp());
}
}
LOG(info) << "Successfully retrieved TPC PID object from CCDB for timestamp " << bc.timestamp() << ", period " << headers["LPMProductionTag"] << ", recoPass " << headers["RecoPassName"];
response->PrintAll();
}

Expand All @@ -515,7 +549,7 @@

PROCESS_SWITCH(tpcPid, processStandard, "Creating PID tables without MC TuneOnData", true);

Partition<TrksMC> mcnotTPCStandaloneTracks = (aod::track::tpcNClsFindable > (uint8_t)0) && ((aod::track::itsClusterSizes > (uint32_t)0) || (aod::track::trdPattern > (uint8_t)0) || (aod::track::tofExpMom > 0.f && aod::track::tofChi2 > 0.f)); // To count number of tracks for use in NN array
Partition<TrksMC> mcnotTPCStandaloneTracks = (aod::track::tpcNClsFindable > static_cast<uint8_t>(0)) && ((aod::track::itsClusterSizes > static_cast<uint32_t>(0)) || (aod::track::trdPattern > static_cast<uint8_t>(0)) || (aod::track::tofExpMom > 0.f && aod::track::tofChi2 > 0.f)); // To count number of tracks for use in NN array
Partition<TrksMC> mctracksWithTPC = (aod::track::tpcNClsFindable > (uint8_t)0);

void processMcTuneOnData(CollMC const& collisionsMc, TrksMC const& tracksMc, aod::BCsWithTimestamps const& bcs, aod::McParticles const&)
Expand Down Expand Up @@ -573,7 +607,7 @@
}
response = ccdb->getSpecific<o2::pid::tpc::Response>(ccdbPath.value, bc.timestamp(), metadata);
if (!response) {
LOGP(warning, "!! Could not find a valid TPC response object for specific pass name {}! Falling back to latest uploaded object.", recoPass.value);
LOGP(warning, "!! Could not find a valid TPC response object for specific pass name {}! Falling back to latest uploaded object.", metadata["RecoPassName"]);
response = ccdb->getForTimeStamp<o2::pid::tpc::Response>(ccdbPath.value, bc.timestamp());
if (!response) {
LOGP(fatal, "Could not find ANY TPC response object for the timestamp {}!", bc.timestamp());
Expand Down Expand Up @@ -641,4 +675,8 @@
PROCESS_SWITCH(tpcPid, processMcTuneOnData, "Creating PID tables with MC TuneOnData", false);
};

WorkflowSpec defineDataProcessing(ConfigContext const& cfgc) { return WorkflowSpec{adaptAnalysisTask<tpcPid>(cfgc)}; }
WorkflowSpec defineDataProcessing(ConfigContext const& cfgc)
{
metadataInfo.initMetadata(cfgc); // Parse AO2D metadata
return WorkflowSpec{adaptAnalysisTask<tpcPid>(cfgc)};
}
10 changes: 7 additions & 3 deletions DPG/Tasks/TPC/tpcSkimsTableCreator.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ struct TreeWriterTpcV0 {
Produces<o2::aod::SkimmedTPCV0Tree> rowTPCTree;

/// Configurables
Configurable<float> nSigmaTOFdautrack{"nSigmaTOFdautrack", 5., "n-sigma TOF cut on the daughter tracks. Set 0 to switch it off."};
Configurable<float> nSigmaTOFdautrack{"nSigmaTOFdautrack", 999., "n-sigma TOF cut on the proton daughter tracks. Set 999 to switch it off."};
Configurable<float> nClNorm{"nClNorm", 152., "Number of cluster normalization. Run 2: 159, Run 3 152"};
Configurable<int> applyEvSel{"applyEvSel", 2, "Flag to apply rapidity cut: 0 -> no event selection, 1 -> Run 2 event selection, 2 -> Run 3 event selection"};
Configurable<int> trackSelection{"trackSelection", 1, "Track selection: 0 -> No Cut, 1 -> kGlobalTrack, 2 -> kGlobalTrackWoPtEta, 3 -> kGlobalTrackWoDCA, 4 -> kQualityTracks, 5 -> kInAcceptanceTracks"};
Expand Down Expand Up @@ -222,7 +222,9 @@ struct TreeWriterTpcV0 {
// Lambda
if (static_cast<bool>(posTrack.pidbit() & (1 << 2)) && static_cast<bool>(negTrack.pidbit() & (1 << 2))) {
if (downsampleTsalisCharged(posTrack.pt(), downsamplingTsalisProtons, sqrtSNN, o2::track::pid_constants::sMasses[o2::track::PID::Proton], maxPt4dwnsmplTsalisProtons)) {
fillSkimmedV0Table(v0, posTrack, collision, posTrack.tpcNSigmaPr(), posTrack.tofNSigmaPr(), posTrack.tpcExpSignalPr(posTrack.tpcSignal()), o2::track::PID::Proton, runnumber, dwnSmplFactor_Pr, hadronicRate);
if (TMath::Abs(posTrack.tofNSigmaPr()) <= nSigmaTOFdautrack) {
fillSkimmedV0Table(v0, posTrack, collision, posTrack.tpcNSigmaPr(), posTrack.tofNSigmaPr(), posTrack.tpcExpSignalPr(posTrack.tpcSignal()), o2::track::PID::Proton, runnumber, dwnSmplFactor_Pr, hadronicRate);
}
}
if (downsampleTsalisCharged(negTrack.pt(), downsamplingTsalisPions, sqrtSNN, o2::track::pid_constants::sMasses[o2::track::PID::Pion], maxPt4dwnsmplTsalisPions)) {
fillSkimmedV0Table(v0, negTrack, collision, negTrack.tpcNSigmaPi(), negTrack.tofNSigmaPi(), negTrack.tpcExpSignalPi(negTrack.tpcSignal()), o2::track::PID::Pion, runnumber, dwnSmplFactor_Pi, hadronicRate);
Expand All @@ -234,7 +236,9 @@ struct TreeWriterTpcV0 {
fillSkimmedV0Table(v0, posTrack, collision, posTrack.tpcNSigmaPi(), posTrack.tofNSigmaPi(), posTrack.tpcExpSignalPi(posTrack.tpcSignal()), o2::track::PID::Pion, runnumber, dwnSmplFactor_Pi, hadronicRate);
}
if (downsampleTsalisCharged(negTrack.pt(), downsamplingTsalisProtons, sqrtSNN, o2::track::pid_constants::sMasses[o2::track::PID::Proton], maxPt4dwnsmplTsalisProtons)) {
fillSkimmedV0Table(v0, negTrack, collision, negTrack.tpcNSigmaPr(), negTrack.tofNSigmaPr(), negTrack.tpcExpSignalPr(negTrack.tpcSignal()), o2::track::PID::Proton, runnumber, dwnSmplFactor_Pr, hadronicRate);
if (TMath::Abs(negTrack.tofNSigmaPr()) <= nSigmaTOFdautrack) {
fillSkimmedV0Table(v0, negTrack, collision, negTrack.tpcNSigmaPr(), negTrack.tofNSigmaPr(), negTrack.tpcExpSignalPr(negTrack.tpcSignal()), o2::track::PID::Proton, runnumber, dwnSmplFactor_Pr, hadronicRate);
}
}
}
}
Expand Down
Loading