Compute and track detailed evaluation metrics on each epoch (#385)
* save jet metrics on each epoch

* format

* fix ray
jpata authored Dec 27, 2024
1 parent a35d94e commit d85a351
Showing 11 changed files with 195 additions and 124 deletions.
7 changes: 5 additions & 2 deletions mlpf/model/PFDataset.py
@@ -82,13 +82,16 @@ def __getitem__(self, item):
             ret["X"][:, 1][msk_ho] = np.sqrt(e**2 - (np.tanh(eta) * e) ** 2)
 
         # transform pt -> log(pt / elem pt), same for energy
-        target_pt = np.log(ret["ytarget"][:, 2] / ret["X"][:, 1])
         # where target does not exist, set to 0
+        with np.errstate(divide="ignore"):
+            target_pt = np.log(ret["ytarget"][:, 2] / ret["X"][:, 1])
         target_pt[np.isnan(target_pt)] = 0
         target_pt[np.isinf(target_pt)] = 0
         ret["ytarget_pt_orig"] = ret["ytarget"][:, 2].copy()
         ret["ytarget"][:, 2] = target_pt
 
-        target_e = np.log(ret["ytarget"][:, 6] / ret["X"][:, 5])
+        with np.errstate(divide="ignore"):
+            target_e = np.log(ret["ytarget"][:, 6] / ret["X"][:, 5])
+        target_e[ret["ytarget"][:, 0] == 0] = 0
         target_e[np.isnan(target_e)] = 0
         target_e[np.isinf(target_e)] = 0
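The change wraps both log-ratio computations in np.errstate(divide="ignore"): np.log emits a RuntimeWarning whenever the target pt or energy is zero, and the resulting -inf (or nan from 0/0) is zeroed out explicitly afterwards. A minimal standalone sketch of the same pattern, with illustrative values rather than real dataset columns:

    import numpy as np

    target = np.array([0.0, 2.0, 8.0])  # target pt per particle; 0 marks "no target"
    elem = np.array([1.0, 2.0, 4.0])    # reference pt of the matched element

    # log(0) raises a divide-by-zero RuntimeWarning; suppress it locally
    with np.errstate(divide="ignore"):
        log_ratio = np.log(target / elem)  # [-inf, 0.0, 0.693...]

    # where the target does not exist, force the regression target to 0
    log_ratio[np.isnan(log_ratio)] = 0
    log_ratio[np.isinf(log_ratio)] = 0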
3 changes: 2 additions & 1 deletion mlpf/model/distributed_ray.py
@@ -252,7 +252,7 @@ def train_ray_trial(config, args, outdir=None):
         _logger.info(table)
 
     if (rank == 0) or (rank == "cpu"):
-        save_HPs(args, model, model_kwargs, outdir)  # save model_kwargs and hyperparameters
+        save_HPs(config, model, model_kwargs, outdir)  # save model_kwargs and hyperparameters
         _logger.info("Creating experiment dir {}".format(outdir))
         _logger.info(f"Model directory {outdir}", color="bold")
 
@@ -312,6 +312,7 @@ def train_ray_trial(config, args, outdir=None):
         config["num_epochs"],
         config["patience"],
         outdir,
+        config,
         trainable=config["model"]["trainable"],
         start_epoch=start_epoch,
         lr_schedule=lr_schedule,
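Passing the full config into the training loop (the new positional argument above) is what lets each epoch's evaluation read its settings; save_HPs likewise now records the resolved Ray config rather than the raw CLI args. Below is a hypothetical sketch of the receiving side under assumed names — the actual callee (presumably train_mlpf) and its parameter order are defined elsewhere in the repository:

    # Hypothetical sketch, not the project's actual signature: shows why the
    # loop needs the whole config dict and not just the scalars already passed.
    def train_loop(num_epochs, patience, outdir, config, trainable=True, start_epoch=0):
        # patience would drive early stopping; omitted in this toy version
        for epoch in range(start_epoch, num_epochs):
            # ... one epoch of training would run here ...
            if config.get("make_plots", True):  # assumed key
                print(f"epoch {epoch}: computing jet metrics into {outdir}")

    train_loop(2, 10, "/tmp/out", {"make_plots": True})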
12 changes: 8 additions & 4 deletions mlpf/model/inference.py
@@ -155,13 +155,15 @@ def run_predictions(world_size, rank, model, loader, sample, outpath, jetdef, je
     ti = time.time()
     for i, batch in iterator:
         predict_one_batch(conv_type, model, i, batch, rank, jetdef, jet_ptcut, jet_match_dr, outpath, dir_name, sample)
+    tf = time.time()
+    time_total_min = (tf - ti) / 60.0
 
-    _logger.info(f"Time taken to make predictions on device {rank} is: {((time.time() - ti) / 60):.2f} min")
+    _logger.info(f"Time taken to make predictions on device {rank} is: {time_total_min:.2f} min")
 
 
 def make_plots(outpath, sample, dataset, dir_name="", ntest_files=-1):
-    """Uses the predictions stored as .parquet files (see above) to make plots."""
-
+    """Uses the predictions stored as .parquet files from run_predictions to make plots."""
+    ret_dict = {}
     mplhep.style.use(mplhep.styles.CMS)
     class_names = get_class_names(sample)
     os.system(f"mkdir -p {outpath}/plots{dir_name}/{sample}")
 
@@ -181,7 +183,7 @@ def make_plots(outpath, sample, dataset, dir_name="", ntest_files=-1):
         dataset=dataset,
         sample=sample,
     )
-    plot_jet_ratio(
+    ret_dict["jet_ratio"] = plot_jet_ratio(
         yvals,
         cp_dir=plots_path,
         bins=np.linspace(0, 5, 500),
 
@@ -230,3 +232,5 @@ def make_plots(outpath, sample, dataset, dir_name="", ntest_files=-1):
     plot_particles(yvals, cp_dir=plots_path, dataset=dataset, sample=sample)
     plot_particle_ratio(yvals, class_names, cp_dir=plots_path, dataset=dataset, sample=sample)
     plot_particle_response(X, yvals, class_names, cp_dir=plots_path, dataset=dataset, sample=sample)
+
+    return ret_dict
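With make_plots returning ret_dict instead of only writing figures to disk, a training loop can pick up the jet metrics at the end of every epoch. A hedged sketch of one way a caller might log them — the SummaryWriter usage is an assumption, and the structure of ret_dict["jet_ratio"] depends on what plot_jet_ratio actually returns:

    from torch.utils.tensorboard import SummaryWriter

    def log_epoch_metrics(writer: SummaryWriter, metrics: dict, epoch: int) -> None:
        # record only scalar entries; nested structures would need flattening
        for name, value in metrics.items():
            if isinstance(value, (int, float)):
                writer.add_scalar(f"valid/{name}", value, epoch)

    writer = SummaryWriter(log_dir="/tmp/runs")  # illustrative path
    metrics = {"jet_ratio": 1.02}                # stand-in for make_plots(...) output
    log_epoch_metrics(writer, metrics, epoch=3)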
2 changes: 1 addition & 1 deletion mlpf/model/mlpf.py
@@ -98,7 +98,7 @@ def __init__(
         self.norm1 = torch.nn.LayerNorm(embedding_dim)
         self.seq = torch.nn.Sequential(nn.Linear(embedding_dim, width), self.act(), nn.Linear(width, embedding_dim), self.act())
         self.dropout = torch.nn.Dropout(dropout_ff)
-        _logger.info("using attention_type={}".format(attention_type))
+        _logger.info("layer {} using attention_type={}".format(self.name, attention_type))
         # params for torch sdp_kernel
         if self.enable_ctx_manager:
             self.attn_params = {
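Tagging the log message with self.name makes the attention backend distinguishable per layer when the encoder stacks many identical blocks; the name attribute is presumably assigned as each layer is constructed. A toy sketch of that pattern with illustrative names, not the module's actual API:

    import torch

    class Block(torch.nn.Module):
        def __init__(self, name: str, embedding_dim: int = 256):
            super().__init__()
            self.name = name  # lets log lines identify which layer is speaking
            self.norm1 = torch.nn.LayerNorm(embedding_dim)
            print(f"layer {self.name} using attention_type=math")  # stand-in for _logger.info

    layers = torch.nn.ModuleList([Block(f"enc_{i}") for i in range(3)])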