Use per-particle pt weight in loss (#383)
* enable pt weight

* retrain CLIC model

* sync with CMS model

* use NANO for CMSSW evals

* update cmssw notebook to nano

* fix dict_keys

* some fixes for distributed

* add scaler

* format

* save plots

* add k0

* update notebook

* update notebook

* update timing

* add triton timing script

---------

Co-authored-by: Joosep Pata <[email protected]>
jpata and Joosep Pata authored Jan 15, 2025
1 parent d85a351 commit 6e852b0
Showing 32 changed files with 859 additions and 500 deletions.
15 changes: 11 additions & 4 deletions mlpf/model/losses.py
@@ -4,6 +4,8 @@
from torch.nn import functional as F
from torch import Tensor, nn

from mlpf.model.logger import _logger


def sliced_wasserstein_loss(y_pred, y_true, num_projections=200):
# create normalized random basis vectors
@@ -74,9 +76,9 @@ def mlpf_loss(y, ypred, batch):
loss_regression_energy[batch.mask == 0] *= 0

# add weight based on target pt
# sqrt_target_pt = torch.sqrt(torch.exp(y["pt"]) * batch.X[:, :, 1])
# loss_regression_pt *= sqrt_target_pt
# loss_regression_energy *= sqrt_target_pt
sqrt_target_pt = torch.sqrt(torch.exp(y["pt"]) * batch.X[:, :, 1])
loss_regression_pt *= sqrt_target_pt
loss_regression_energy *= sqrt_target_pt

# average over all target particles
loss["Regression_pt"] = loss_regression_pt.sum() / npart
Expand Down Expand Up @@ -122,10 +124,15 @@ def mlpf_loss(y, ypred, batch):
+ loss["Regression_energy"]
)
loss_opt = loss["Total"]
if torch.isnan(loss_opt):
_logger.error(ypred)
_logger.error(sqrt_target_pt)
_logger.error(loss)
raise Exception("Loss became NaN")

# store these separately but detached
for k in loss.keys():
loss[k] = loss[k].detach().cpu().item()
loss[k] = loss[k].detach()

return loss_opt, loss
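
Note also the change in the bookkeeping dict: loss[k].detach().cpu().item() becomes loss[k].detach(), so the per-term losses remain tensors on the training device (with the autograd graph cut) rather than Python floats. This is what later allows them to be accumulated and reduced across ranks; a tiny illustration of the difference:

import torch

x = torch.tensor([2.0], requires_grad=True)
loss = (x ** 2).sum()

as_tensor = loss.detach()               # still a tensor on its device, graph is cut
as_float = loss.detach().cpu().item()   # plain Python float, cannot be all_reduce'd
print(type(as_tensor), type(as_float))  # <class 'torch.Tensor'> <class 'float'>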

90 changes: 48 additions & 42 deletions mlpf/model/plots.py
@@ -123,45 +123,51 @@ def validation_plots(batch, ypred_raw, ytarget, ypred, tensorboard_writer, epoch
plt.xlabel("particle proba")
tensorboard_writer.add_figure("sig_proba_elemtype{}".format(int(xcls)), fig, global_step=epoch)

tensorboard_writer.add_histogram("pt_target", torch.clamp(batch.ytarget[batch.mask][:, 2], -10, 10), global_step=epoch)
tensorboard_writer.add_histogram("pt_pred", torch.clamp(ypred_raw[2][batch.mask][:, 0], -10, 10), global_step=epoch)
ratio = (ypred_raw[2][batch.mask][:, 0] / batch.ytarget[batch.mask][:, 2])[batch.ytarget[batch.mask][:, 0] != 0]
tensorboard_writer.add_histogram("pt_ratio", torch.clamp(ratio, -10, 10), global_step=epoch)

tensorboard_writer.add_histogram("eta_target", torch.clamp(batch.ytarget[batch.mask][:, 3], -10, 10), global_step=epoch)
tensorboard_writer.add_histogram("eta_pred", torch.clamp(ypred_raw[2][batch.mask][:, 1], -10, 10), global_step=epoch)
ratio = (ypred_raw[2][batch.mask][:, 1] / batch.ytarget[batch.mask][:, 3])[batch.ytarget[batch.mask][:, 0] != 0]
tensorboard_writer.add_histogram("eta_ratio", torch.clamp(ratio, -10, 10), global_step=epoch)

tensorboard_writer.add_histogram("sphi_target", torch.clamp(batch.ytarget[batch.mask][:, 4], -10, 10), global_step=epoch)
tensorboard_writer.add_histogram("sphi_pred", torch.clamp(ypred_raw[2][batch.mask][:, 2], -10, 10), global_step=epoch)
ratio = (ypred_raw[2][batch.mask][:, 2] / batch.ytarget[batch.mask][:, 4])[batch.ytarget[batch.mask][:, 0] != 0]
tensorboard_writer.add_histogram("sphi_ratio", torch.clamp(ratio, -10, 10), global_step=epoch)

tensorboard_writer.add_histogram("cphi_target", torch.clamp(batch.ytarget[batch.mask][:, 5], -10, 10), global_step=epoch)
tensorboard_writer.add_histogram("cphi_pred", torch.clamp(ypred_raw[2][batch.mask][:, 3], -10, 10), global_step=epoch)
ratio = (ypred_raw[2][batch.mask][:, 3] / batch.ytarget[batch.mask][:, 5])[batch.ytarget[batch.mask][:, 0] != 0]
tensorboard_writer.add_histogram("cphi_ratio", torch.clamp(ratio, -10, 10), global_step=epoch)

tensorboard_writer.add_histogram("energy_target", torch.clamp(batch.ytarget[batch.mask][:, 6], -10, 10), global_step=epoch)
tensorboard_writer.add_histogram("energy_pred", torch.clamp(ypred_raw[2][batch.mask][:, 4], -10, 10), global_step=epoch)
ratio = (ypred_raw[2][batch.mask][:, 4] / batch.ytarget[batch.mask][:, 6])[batch.ytarget[batch.mask][:, 0] != 0]
tensorboard_writer.add_histogram("energy_ratio", torch.clamp(ratio, -10, 10), global_step=epoch)

for attn in sorted(list(glob.glob(f"{outdir}/attn_conv_*.npz"))):
attn_name = os.path.basename(attn).split(".")[0]
attn_matrix = np.load(attn)["att"]
batch_size = min(attn_matrix.shape[0], 8)
fig, axes = plt.subplots(1, batch_size, figsize=((batch_size * 3, 1 * 3)))
if isinstance(axes, matplotlib.axes._axes.Axes):
axes = [axes]
for ibatch in range(batch_size):
plt.sca(axes[ibatch])
# plot the attention matrix of the first event in the batch
plt.imshow(attn_matrix[ibatch].T, cmap="hot", norm=matplotlib.colors.LogNorm())
plt.xticks([])
plt.yticks([])
plt.colorbar()
plt.title("event {}, m={:.2E}".format(ibatch, np.mean(attn_matrix[ibatch][attn_matrix[ibatch] > 0])))
plt.suptitle(attn_name)
tensorboard_writer.add_figure(attn_name, fig, global_step=epoch)
try:
tensorboard_writer.add_histogram("pt_target", torch.clamp(batch.ytarget[batch.mask][:, 2], -10, 10), global_step=epoch)
tensorboard_writer.add_histogram("pt_pred", torch.clamp(ypred_raw[2][batch.mask][:, 0], -10, 10), global_step=epoch)
ratio = (ypred_raw[2][batch.mask][:, 0] / batch.ytarget[batch.mask][:, 2])[batch.ytarget[batch.mask][:, 0] != 0]
tensorboard_writer.add_histogram("pt_ratio", torch.clamp(ratio, -10, 10), global_step=epoch)

tensorboard_writer.add_histogram("eta_target", torch.clamp(batch.ytarget[batch.mask][:, 3], -10, 10), global_step=epoch)
tensorboard_writer.add_histogram("eta_pred", torch.clamp(ypred_raw[2][batch.mask][:, 1], -10, 10), global_step=epoch)
ratio = (ypred_raw[2][batch.mask][:, 1] / batch.ytarget[batch.mask][:, 3])[batch.ytarget[batch.mask][:, 0] != 0]
tensorboard_writer.add_histogram("eta_ratio", torch.clamp(ratio, -10, 10), global_step=epoch)

tensorboard_writer.add_histogram("sphi_target", torch.clamp(batch.ytarget[batch.mask][:, 4], -10, 10), global_step=epoch)
tensorboard_writer.add_histogram("sphi_pred", torch.clamp(ypred_raw[2][batch.mask][:, 2], -10, 10), global_step=epoch)
ratio = (ypred_raw[2][batch.mask][:, 2] / batch.ytarget[batch.mask][:, 4])[batch.ytarget[batch.mask][:, 0] != 0]
tensorboard_writer.add_histogram("sphi_ratio", torch.clamp(ratio, -10, 10), global_step=epoch)

tensorboard_writer.add_histogram("cphi_target", torch.clamp(batch.ytarget[batch.mask][:, 5], -10, 10), global_step=epoch)
tensorboard_writer.add_histogram("cphi_pred", torch.clamp(ypred_raw[2][batch.mask][:, 3], -10, 10), global_step=epoch)
ratio = (ypred_raw[2][batch.mask][:, 3] / batch.ytarget[batch.mask][:, 5])[batch.ytarget[batch.mask][:, 0] != 0]
tensorboard_writer.add_histogram("cphi_ratio", torch.clamp(ratio, -10, 10), global_step=epoch)

tensorboard_writer.add_histogram("energy_target", torch.clamp(batch.ytarget[batch.mask][:, 6], -10, 10), global_step=epoch)
tensorboard_writer.add_histogram("energy_pred", torch.clamp(ypred_raw[2][batch.mask][:, 4], -10, 10), global_step=epoch)
ratio = (ypred_raw[2][batch.mask][:, 4] / batch.ytarget[batch.mask][:, 6])[batch.ytarget[batch.mask][:, 0] != 0]
tensorboard_writer.add_histogram("energy_ratio", torch.clamp(ratio, -10, 10), global_step=epoch)
except ValueError as e:
print(e)

try:
for attn in sorted(list(glob.glob(f"{outdir}/attn_conv_*.npz"))):
attn_name = os.path.basename(attn).split(".")[0]
attn_matrix = np.load(attn)["att"]
batch_size = min(attn_matrix.shape[0], 8)
fig, axes = plt.subplots(1, batch_size, figsize=((batch_size * 3, 1 * 3)))
if isinstance(axes, matplotlib.axes._axes.Axes):
axes = [axes]
for ibatch in range(batch_size):
plt.sca(axes[ibatch])
# plot the attention matrix of each event in the batch
plt.imshow(attn_matrix[ibatch].T, cmap="hot", norm=matplotlib.colors.LogNorm())
plt.xticks([])
plt.yticks([])
plt.colorbar()
plt.title("event {}, m={:.2E}".format(ibatch, np.mean(attn_matrix[ibatch][attn_matrix[ibatch] > 0])))
plt.suptitle(attn_name)
tensorboard_writer.add_figure(attn_name, fig, global_step=epoch)
except ValueError as e:
print(e)
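
Both the histogram block and the attention-matrix block are now wrapped in try/except ValueError, so a degenerate batch (for example, one where no values pass the selection, which makes add_histogram raise on empty input) no longer aborts the whole validation step. A hedged sketch of the guarded pattern; the tag and value names are illustrative:

import torch
from torch.utils.tensorboard import SummaryWriter

def log_histogram_safely(writer: SummaryWriter, tag: str, values: torch.Tensor, step: int):
    try:
        writer.add_histogram(tag, torch.clamp(values, -10, 10), global_step=step)
    except ValueError as err:
        # add_histogram can raise ValueError for empty or invalid inputs; log and move on
        print(f"skipping histogram {tag}: {err}")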
104 changes: 42 additions & 62 deletions mlpf/model/training.py
@@ -67,57 +67,26 @@ def configure_model_trainable(model: MLPF, trainable: Union[str, List[str]], is_
model.eval()


def train_step(batch, model, optimizer, lr_schedule, loss_fn):
"""Single training step logic
Args:
batch: The input batch data
model: The neural network model
optimizer: The optimizer
lr_schedule: Learning rate scheduler
loss_fn: Loss function to use
Returns:
dict: Dictionary containing all computed losses with gradient detached
"""
def model_step(batch, model, loss_fn):
ypred_raw = model(batch.X, batch.mask)
ypred = unpack_predictions(ypred_raw)
ytarget = unpack_target(batch.ytarget, model)

loss_opt, losses_detached = loss_fn(ytarget, ypred, batch)
return loss_opt, losses_detached, ypred_raw, ypred, ytarget


def optimizer_step(model, loss_opt, optimizer, lr_schedule, scaler):
# Clear gradients
for param in model.parameters():
param.grad = None

# Backward pass and optimization
loss_opt.backward()
optimizer.step()
scaler.scale(loss_opt).backward()
scaler.step(optimizer)
scaler.update()
if lr_schedule:
lr_schedule.step()

return losses_detached
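
The old train_step is split into model_step (forward pass and loss, shared with evaluation) and optimizer_step (backward pass and parameter update), and the backward pass now goes through a gradient scaler so mixed-precision training does not underflow small gradients. A hedged sketch of the standard torch.amp pattern this follows; model, optimizer, loader and loss_fn are assumed to exist, and the device/dtype choices are illustrative:

import torch

scaler = torch.amp.GradScaler()

for batch_x, batch_y in loader:
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        pred = model(batch_x)            # "model_step": forward pass
        loss = loss_fn(pred, batch_y)    # and loss computation

    for param in model.parameters():     # "optimizer_step": cheap gradient reset
        param.grad = None
    scaler.scale(loss).backward()        # scale the loss before backward
    scaler.step(optimizer)               # unscales grads, skips the step on inf/nan
    scaler.update()                      # adjust the scale factor for the next step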


def eval_step(batch, model, loss_fn):
"""Single evaluation step logic
Args:
batch: The input batch data
model: The neural network model
loss_fn: Loss function to use
Returns:
tuple: (losses dict, predictions dict, targets dict)
"""
with torch.no_grad():
ypred_raw = model(batch.X, batch.mask)
ypred = unpack_predictions(ypred_raw)
ytarget = unpack_target(batch.ytarget, model)
_, losses_detached = loss_fn(ytarget, ypred, batch)

return losses_detached, ypred_raw, ypred, ytarget


def train_epoch(
rank: Union[int, str],
@@ -133,6 +102,7 @@ def train_epoch(
checkpoint_dir="",
device_type="cuda",
dtype=torch.float32,
scaler=None,
):
"""Run one training epoch
@@ -167,7 +137,9 @@ def train_epoch(
batch = batch.to(rank, non_blocking=True)

with torch.autocast(device_type=device_type, dtype=dtype, enabled=device_type == "cuda"):
loss = train_step(batch, model, optimizer, lr_schedule, mlpf_loss)
loss_opt, loss, _, _, _ = model_step(batch, model, mlpf_loss)

optimizer_step(model, loss_opt, optimizer, lr_schedule, scaler)

# Accumulate losses
for loss_name in loss:
@@ -191,14 +163,14 @@ def train_epoch(
comet_experiment.log_metric("learning_rate", lr_schedule.get_last_lr(), step=step)

# Average losses across steps
num_steps = len(train_loader)
num_steps = torch.tensor(float(len(train_loader)), device=rank, dtype=torch.float32)
if world_size > 1:
torch.distributed.all_reduce(num_steps)

for loss_name in epoch_loss:
if world_size > 1:
torch.distributed.all_reduce(epoch_loss[loss_name])
epoch_loss[loss_name] = epoch_loss[loss_name] / num_steps
epoch_loss[loss_name] = epoch_loss[loss_name].cpu().item() / num_steps.cpu().item()

if world_size > 1:
dist.barrier()
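
The loss averaging now works on tensors: num_steps is promoted to a device tensor so that, under distributed training, both the accumulated per-term losses and the step counts can be summed over ranks with all_reduce (whose default reduction is SUM) before dividing. A hedged sketch of the pattern with illustrative names:

import torch
import torch.distributed as dist

def reduce_epoch_loss(epoch_loss, num_local_steps, world_size, device):
    num_steps = torch.tensor(float(num_local_steps), device=device)
    if world_size > 1:
        dist.all_reduce(num_steps)             # total number of steps across ranks
    for name in epoch_loss:
        if world_size > 1:
            dist.all_reduce(epoch_loss[name])  # total loss across ranks
        epoch_loss[name] = (epoch_loss[name] / num_steps).cpu().item()
    return epoch_loss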
@@ -261,7 +233,8 @@ def eval_epoch(
set_save_attention(model, outdir, False)

with torch.autocast(device_type=device_type, dtype=dtype, enabled=device_type == "cuda"):
loss, ypred_raw, ypred, ytarget = eval_step(batch, model, mlpf_loss)
with torch.no_grad():
loss_opt, loss, ypred_raw, ypred, ytarget = model_step(batch, model, mlpf_loss)

# Update confusion matrices
cm_X_target += sklearn.metrics.confusion_matrix(
@@ -297,14 +270,14 @@ def eval_epoch(
)

# Average losses across steps
num_steps = len(valid_loader)
num_steps = torch.tensor(float(len(valid_loader)), device=rank, dtype=torch.float32)
if world_size > 1:
torch.distributed.all_reduce(num_steps)

for loss_name in epoch_loss:
if world_size > 1:
torch.distributed.all_reduce(epoch_loss[loss_name])
epoch_loss[loss_name] = epoch_loss[loss_name] / num_steps
epoch_loss[loss_name] = epoch_loss[loss_name].cpu().item() / num_steps.cpu().item()

if world_size > 1:
dist.barrier()
@@ -383,6 +356,8 @@ def train_all_epochs(
stale_epochs = torch.tensor(0, device=rank)
best_val_loss = float("inf")

scaler = torch.amp.GradScaler()

for epoch in range(start_epoch, num_epochs + 1):
epoch_start_time = time.time()

@@ -401,6 +376,7 @@ def train_all_epochs(
checkpoint_dir=checkpoint_dir,
device_type=device_type,
dtype=dtype,
scaler=scaler,
)
train_time = time.time() - epoch_start_time

@@ -430,21 +406,6 @@ def train_all_epochs(

# Handle checkpointing and early stopping on rank 0
if (rank == 0) or (rank == "cpu"):

# evaluate the model at this epoch on test datasets, make plots, track metrics
testdir_name = f"_epoch_{epoch}"
for sample in config["test_dataset"]:
run_test(rank, world_size, config, outdir, model, sample, testdir_name, dtype)
plot_metrics = make_plots(outdir, sample, config["dataset"], testdir_name, config["ntest"])

# track the following jet metrics in tensorboard
for k in ["med", "iqr", "match_frac"]:
tensorboard_writer_valid.add_scalar(
"epoch/{}/jet_ratio/jet_ratio_target_to_pred_pt/{}".format(sample, k),
plot_metrics["jet_ratio"]["jet_ratio_target_to_pred_pt"][k],
epoch,
)

# Log learning rate
tensorboard_writer_train.add_scalar("epoch/learning_rate", lr_schedule.get_last_lr()[0], epoch)

@@ -504,6 +465,20 @@ def train_all_epochs(
tensorboard_writer_train.flush()
tensorboard_writer_valid.flush()

# evaluate the model at this epoch on test datasets, make plots, track metrics
testdir_name = f"_epoch_{epoch}"
for sample in config["enabled_test_datasets"]:
run_test(rank, world_size, config, outdir, model, sample, testdir_name, dtype)
plot_metrics = make_plots(outdir, sample, config["dataset"], testdir_name, config["ntest"])

# track the following jet metrics in tensorboard
for k in ["med", "iqr", "match_frac"]:
tensorboard_writer_valid.add_scalar(
"epoch/{}/jet_ratio/jet_ratio_target_to_pred_pt/{}".format(sample, k),
plot_metrics["jet_ratio"]["jet_ratio_target_to_pred_pt"][k],
epoch,
)
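
The per-epoch test evaluation now runs after the checkpointing and early-stopping logic and only over config["enabled_test_datasets"]. The three tracked numbers summarize the jet pT response distribution: its median, interquartile range, and the matched-jet fraction. The repository computes them inside make_plots; purely as an illustration of what such a summary could look like (all names here are assumptions):

import numpy as np

def jet_response_summary(ratio: np.ndarray, matched: np.ndarray) -> dict:
    # ratio: target-to-predicted jet pT for matched jets; matched: boolean per target jet
    q25, med, q75 = np.percentile(ratio, [25, 50, 75])
    return {
        "med": float(med),                      # median response (scale)
        "iqr": float(q75 - q25),                # interquartile range (resolution proxy)
        "match_frac": float(np.mean(matched)),  # fraction of jets that found a match
    }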

# Ray training specific logging
if use_ray:
import ray
@@ -787,14 +762,14 @@ def run(rank, world_size, config, outdir, logfile):
testdir_name = "_best_weights"

if config["test"]:
for sample in config["test_dataset"]:
for sample in config["enabled_test_datasets"]:
run_test(rank, world_size, config, outdir, model, sample, testdir_name, dtype)

# make plots only on a single machine
if (rank == 0) or (rank == "cpu"):
if config["make_plots"]:
ntest_files = -1
for sample in config["test_dataset"]:
for sample in config["enabled_test_datasets"]:
_logger.info(f"Plotting distributions for {sample}")
make_plots(outdir, sample, config["dataset"], testdir_name, ntest_files)

@@ -817,8 +792,13 @@ def override_config(config: dict, args):
for model in ["gnn_lsh", "attention", "attention", "mamba"]:
config["model"][model]["num_convs"] = args.num_convs

config["enabled_test_datasets"] = list(config["test_dataset"].keys())
if len(args.test_datasets) != 0:
config["test_dataset"] = args.test_datasets
config["enabled_test_datasets"] = args.test_datasets

config["train"] = args.train
config["test"] = args.test
config["make_plots"] = args.make_plots

return config
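
override_config now leaves the full test_dataset mapping from the YAML config untouched and records the samples that are actually enabled under a separate enabled_test_datasets key, defaulting to all of them and narrowing only when test datasets are requested on the command line. A hedged sketch of that pattern (the validation step and argument names are illustrative additions):

def select_test_datasets(config: dict, requested: list) -> dict:
    config["enabled_test_datasets"] = list(config["test_dataset"].keys())
    if requested:
        unknown = [s for s in requested if s not in config["test_dataset"]]
        if unknown:
            raise ValueError(f"unknown test datasets: {unknown}")
        config["enabled_test_datasets"] = requested
    return config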

3 changes: 2 additions & 1 deletion mlpf/plotting/plot_utils.py
@@ -136,6 +136,7 @@ def get_class_names(sample_name):
"cms_pf_single_pi0": r"single neutral pion particle gun events",
"cms_pf_single_proton": r"single proton particle gun events",
"cms_pf_single_tau": r"single tau particle gun events",
"cms_pf_single_k0": r"single K0 particle gun events",
"cms_pf_sms_t1tttt": r"sms t1tttt events",
}

@@ -418,7 +419,7 @@ def compute_3dmomentum_and_ratio(yvals):
}


def save_img(outfile, epoch, cp_dir=None, comet_experiment=None):
def save_img(outfile, epoch=None, cp_dir=None, comet_experiment=None):
if cp_dir:
image_path = str(cp_dir / outfile)
plt.savefig(image_path, dpi=100, bbox_inches="tight")
11 changes: 7 additions & 4 deletions mlpf/timing.py
@@ -77,11 +77,11 @@ def get_mem_mb(use_gpu):
sess_options.add_session_config_entry("session.intra_op.allow_spinning", "1")

onnx_sess = rt.InferenceSession(args.model, sess_options, providers=EP_list)
# warmup

mem_onnx = get_mem_mb(use_gpu)
print("mem_onnx", mem_onnx)

# warmup
X = np.array(np.random.randn(batch_size, bin_size, num_features), getattr(np, args.input_dtype))
for i in range(10):
onnx_sess.run(None, {"Xfeat_normed": X, "mask": (X[..., 0] != 0).astype(np.float32)})
@@ -103,9 +103,12 @@ def get_mem_mb(use_gpu):

# transfer data to GPU, run model, transfer data back
t0 = time.time()
# pred_onx = onnx_sess.run(None, {"Xfeat_normed": X, "l_mask_": X[..., 0]==0})
pred_onx = onnx_sess.run(None, {"Xfeat_normed": X, "mask": (X[..., 0] != 0).astype(np.float32)})
t1 = time.time()
try:
onnx_sess.run(None, {"Xfeat_normed": X, "mask": (X[..., 0] != 0).astype(np.float32)})
t1 = time.time()
except Exception as e:
print(e)
t1 = t0
dt = (t1 - t0) / batch_size
times.append(dt)
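
The timed inference call is wrapped in try/except so that a failure at a given batch size (for example, running out of memory) records a zero-duration step and prints the error instead of aborting the whole scan, and the warmup runs are moved after the memory snapshot. A hedged sketch of the overall measurement loop; the model path and tensor shapes are illustrative, while the input names Xfeat_normed and mask come from this script:

import time
import numpy as np
import onnxruntime as rt

sess = rt.InferenceSession("model.onnx", providers=["CPUExecutionProvider"])
batch_size, bin_size, num_features = 4, 256, 17   # illustrative shapes
X = np.random.randn(batch_size, bin_size, num_features).astype(np.float32)
feeds = {"Xfeat_normed": X, "mask": (X[..., 0] != 0).astype(np.float32)}

for _ in range(10):       # warmup, excluded from the timing
    sess.run(None, feeds)

times = []
for _ in range(100):
    t0 = time.time()
    try:
        sess.run(None, feeds)
        t1 = time.time()
    except Exception as err:   # e.g. allocation failure at large batch sizes
        print(err)
        t1 = t0
    times.append((t1 - t0) / batch_size)   # seconds per event

print(f"mean latency: {np.mean(times) * 1e3:.3f} ms/event")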

Expand Down