Skip to content

Commit

Permalink
feat: visualize long-range affinities, if exist
Browse files Browse the repository at this point in the history
  • Loading branch information
torms3 authored and supersergiy committed Dec 13, 2023
1 parent 7e46b87 commit 568111a
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 14 deletions.
3 changes: 0 additions & 3 deletions zetta_utils/training/lightning/regimes/common.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from __future__ import annotations

from functools import reduce
from typing import Callable

import wandb
from pytorch_lightning.loggers.logger import Logger
Expand Down Expand Up @@ -59,7 +58,6 @@ def render_3d_result(data: Tensor):

def log_3d_results(
mode: str,
transforms: dict[str, Callable],
title_suffix: str = "",
**kwargs,
) -> None:
Expand All @@ -69,7 +67,6 @@ def log_3d_results(
row = []
for k, v in kwargs.items():
data = tensor_ops.crop_center(v, min_s)
data = transforms[k](data) if k in transforms else data
rendered = render_3d_result(data)
row.append(wandb.Image(rendered, caption=k))

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ class BaseAffinityRegime(pl.LightningModule):
lr: float
amsgrad: bool = True
logits: bool = True
group: int = 3

train_log_row_interval: int = 200
val_log_row_interval: int = 25
Expand Down Expand Up @@ -64,23 +65,31 @@ def compute_loss(
self.log(f"loss/{mode}", loss.item(), on_step=True, on_epoch=True)

if log_row:
results = {
"data_in": data_in,
"target": target,
"result": torch.sigmoid(result) if self.logits else result,
}
results = {"data_in": data_in}

if torch.count_nonzero(mask) < torch.numel(mask):
results["target_mask"] = mask
target_ = target
result_ = torch.sigmoid(result) if self.logits else result

# RGB transform, if necessary
transforms = {}
# RGB transfrom, if necessary
if isinstance(self.criterion, AffinityLoss):
transforms["target"] = tensor_ops.label.seg_to_rgb
target_ = tensor_ops.seg_to_rgb(target)

# Chop into groups for visualization purpose
num_channels = target.shape[-4]
group = self.group if self.group > 0 else num_channels
for i in range(0, num_channels, group):
start = i
end = min(i + group, num_channels)
results[f"target[{start}:{end}]"] = target_[..., start:end, :, :, :]
results[f"result[{start}:{end}]"] = result_[..., start:end, :, :, :]

# Optional mask
mask_ = mask[..., start:end, :, :, :]
if torch.count_nonzero(mask_) < torch.numel(mask_):
results[f"target_mask[{start}:{end}]"] = mask_

log_3d_results(
mode,
transforms=transforms,
title_suffix=sample_name,
**results,
)
Expand Down

0 comments on commit 568111a

Please sign in to comment.