From 9ddf8818fede27a119f5912ba17ecbfcc1220a40 Mon Sep 17 00:00:00 2001 From: supersergiy Date: Wed, 29 Nov 2023 16:41:00 -0800 Subject: [PATCH] Alignment Subrepo --- .github/workflows/docs_deployment.yaml | 16 +- .github/workflows/testing.yaml | 65 ++- .github/workflows/testing_integration.yaml | 18 +- .gitmodules | 3 + zetta_utils/__init__.py | 13 +- zetta_utils/alignment | 1 + zetta_utils/alignment/__init__.py | 9 - zetta_utils/alignment/aced_relaxation.py | 471 ------------------ zetta_utils/alignment/base_coarsener.py | 82 --- zetta_utils/alignment/base_encoder.py | 55 -- zetta_utils/alignment/defect_detector.py | 80 --- zetta_utils/alignment/encoding_coarsener.py | 41 -- zetta_utils/alignment/field.py | 199 -------- .../alignment/misalignment_detector.py | 54 -- zetta_utils/alignment/online_finetuner.py | 110 ---- zetta_utils/alignment/resin_detector.py | 114 ----- 16 files changed, 95 insertions(+), 1236 deletions(-) create mode 100644 .gitmodules create mode 160000 zetta_utils/alignment delete mode 100644 zetta_utils/alignment/__init__.py delete mode 100644 zetta_utils/alignment/aced_relaxation.py delete mode 100644 zetta_utils/alignment/base_coarsener.py delete mode 100644 zetta_utils/alignment/base_encoder.py delete mode 100644 zetta_utils/alignment/defect_detector.py delete mode 100644 zetta_utils/alignment/encoding_coarsener.py delete mode 100644 zetta_utils/alignment/field.py delete mode 100644 zetta_utils/alignment/misalignment_detector.py delete mode 100644 zetta_utils/alignment/online_finetuner.py delete mode 100644 zetta_utils/alignment/resin_detector.py diff --git a/.github/workflows/docs_deployment.yaml b/.github/workflows/docs_deployment.yaml index 425b687f0..942a575af 100644 --- a/.github/workflows/docs_deployment.yaml +++ b/.github/workflows/docs_deployment.yaml @@ -19,12 +19,22 @@ jobs: url: ${{ steps.deployment.outputs.page_url }} runs-on: ubuntu-latest steps: - - name: Checkout - uses: actions/checkout@v3 + - name: Get token from Github App + uses: actions/create-github-app-token@v1 + id: app_token + with: + app-id: ${{ secrets.APP_ID }} + private-key: ${{ secrets.APP_PEM }} + # owner is required, otherwise the creds will fail the checkout step + owner: ${{ github.repository_owner }} + + - name: Checkout from GitHub + uses: actions/checkout@v4 with: lfs: 'false' - submodules: 'recursive' + submodules: true ssh-key: ${{ secrets.git_ssh_key }} + token: ${{ steps.app_token.outputs.token }} - name: Setup Python uses: actions/setup-python@v4 with: diff --git a/.github/workflows/testing.yaml b/.github/workflows/testing.yaml index 836a8c5d1..6cb3b71c4 100644 --- a/.github/workflows/testing.yaml +++ b/.github/workflows/testing.yaml @@ -21,12 +21,23 @@ jobs: - "3.10" runs-on: ${{ matrix.os }} steps: - - name: Checkout - uses: actions/checkout@v3 + - name: Get token from Github App + uses: actions/create-github-app-token@v1 + id: app_token + with: + app-id: ${{ secrets.APP_ID }} + private-key: ${{ secrets.APP_PEM }} + # owner is required, otherwise the creds will fail the checkout step + owner: ${{ github.repository_owner }} + + - name: Checkout from GitHub + uses: actions/checkout@v4 with: lfs: 'false' - submodules: 'recursive' + submodules: true ssh-key: ${{ secrets.git_ssh_key }} + token: ${{ steps.app_token.outputs.token }} + - name: Get changed files uses: dorny/paths-filter@v2 id: filter @@ -75,12 +86,22 @@ jobs: - "3.10" runs-on: ${{ matrix.os }} steps: - - name: Checkout - uses: actions/checkout@v3 + - name: Get token from Github App + uses: actions/create-github-app-token@v1 + id: app_token + with: + app-id: ${{ secrets.APP_ID }} + private-key: ${{ secrets.APP_PEM }} + # owner is required, otherwise the creds will fail the checkout step + owner: ${{ github.repository_owner }} + + - name: Checkout from GitHub + uses: actions/checkout@v4 with: lfs: 'false' - submodules: 'recursive' + submodules: true ssh-key: ${{ secrets.git_ssh_key }} + token: ${{ steps.app_token.outputs.token }} - name: Get changed files uses: dorny/paths-filter@v2 id: filter @@ -116,12 +137,22 @@ jobs: - "3.10" runs-on: ${{ matrix.os }} steps: - - name: Checkout - uses: actions/checkout@v3 + - name: Get token from Github App + uses: actions/create-github-app-token@v1 + id: app_token + with: + app-id: ${{ secrets.APP_ID }} + private-key: ${{ secrets.APP_PEM }} + # owner is required, otherwise the creds will fail the checkout step + owner: ${{ github.repository_owner }} + + - name: Checkout from GitHub + uses: actions/checkout@v4 with: lfs: 'false' - submodules: 'recursive' + submodules: true ssh-key: ${{ secrets.git_ssh_key }} + token: ${{ steps.app_token.outputs.token }} - name: Get changed files uses: dorny/paths-filter@v2 id: filter @@ -155,12 +186,22 @@ jobs: - "3.10" runs-on: ${{ matrix.os }} steps: - - name: Checkout - uses: actions/checkout@v3 + - name: Get token from Github App + uses: actions/create-github-app-token@v1 + id: app_token + with: + app-id: ${{ secrets.APP_ID }} + private-key: ${{ secrets.APP_PEM }} + # owner is required, otherwise the creds will fail the checkout step + owner: ${{ github.repository_owner }} + + - name: Checkout from GitHub + uses: actions/checkout@v4 with: lfs: 'false' - submodules: 'recursive' + submodules: true ssh-key: ${{ secrets.git_ssh_key }} + token: ${{ steps.app_token.outputs.token }} - name: Get changed files uses: dorny/paths-filter@v2 id: filter diff --git a/.github/workflows/testing_integration.yaml b/.github/workflows/testing_integration.yaml index 47a0fa544..e374a7966 100644 --- a/.github/workflows/testing_integration.yaml +++ b/.github/workflows/testing_integration.yaml @@ -17,12 +17,22 @@ jobs: - "3.10" runs-on: ${{ matrix.os }} steps: - - name: Checkout - uses: actions/checkout@v3 + - name: Get token from Github App + uses: actions/create-github-app-token@v1 + id: app_token with: - lfs: 'true' - submodules: 'recursive' + app-id: ${{ secrets.APP_ID }} + private-key: ${{ secrets.APP_PEM }} + # owner is required, otherwise the creds will fail the checkout step + owner: ${{ github.repository_owner }} + + - name: Checkout from GitHub + uses: actions/checkout@v4 + with: + lfs: 'false' + submodules: true ssh-key: ${{ secrets.git_ssh_key }} + token: ${{ steps.app_token.outputs.token }} - name: Get changed files uses: dorny/paths-filter@v2 id: filter diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 000000000..151433cc5 --- /dev/null +++ b/.gitmodules @@ -0,0 +1,3 @@ +[submodule "zetta_utils/alignment"] + path = zetta_utils/alignment + url = git@github.com:ZettaAI/alignment.git diff --git a/zetta_utils/__init__.py b/zetta_utils/__init__.py index 6eeb91ab2..ce2022917 100644 --- a/zetta_utils/__init__.py +++ b/zetta_utils/__init__.py @@ -31,9 +31,15 @@ def try_load_train_inference(): # pragma: no cover ... +def try_load_submodules(): # pragma: no cover + try: + from . import alignment + except ImportError: + ... + + def load_inference_modules(): from . import ( - alignment, augmentations, convnet, mazepa, @@ -47,10 +53,11 @@ def load_inference_modules(): from .layer.volumetric import cloudvol from .message_queues import sqs + try_load_submodules() + def load_training_modules(): from . import ( - alignment, augmentations, convnet, mazepa, @@ -62,5 +69,7 @@ def load_training_modules(): from .layer import volumetric from .layer.volumetric import cloudvol + try_load_submodules() + try_load_train_inference() diff --git a/zetta_utils/alignment b/zetta_utils/alignment new file mode 160000 index 000000000..057e62be5 --- /dev/null +++ b/zetta_utils/alignment @@ -0,0 +1 @@ +Subproject commit 057e62be5f2c4ca2b0591a0dfaedea585a3da238 diff --git a/zetta_utils/alignment/__init__.py b/zetta_utils/alignment/__init__.py deleted file mode 100644 index f3492368c..000000000 --- a/zetta_utils/alignment/__init__.py +++ /dev/null @@ -1,9 +0,0 @@ -from . import field -from . import online_finetuner -from . import aced_relaxation -from .base_encoder import BaseEncoder -from .base_coarsener import BaseCoarsener -from .encoding_coarsener import EncodingCoarsener -from .misalignment_detector import MisalignmentDetector -from .defect_detector import DefectDetector -from .resin_detector import ResinDetector diff --git a/zetta_utils/alignment/aced_relaxation.py b/zetta_utils/alignment/aced_relaxation.py deleted file mode 100644 index 2feb68eab..000000000 --- a/zetta_utils/alignment/aced_relaxation.py +++ /dev/null @@ -1,471 +0,0 @@ -# pylint: disable=too-many-locals -from __future__ import annotations - -from typing import Dict, List, Literal, Optional - -import attrs -import einops - -# import metroem -import torch -import torchfields # pylint: disable=unused-import # monkeypatch - -from zetta_utils import builder, log - -from .field import get_rigidity_map_zcxy, invert_field - -logger = log.get_logger("zetta_utils") - - -def compute_aced_loss_new( - pfields_raw: Dict[int, torch.Tensor], - afields: List[torch.Tensor], - match_offsets: List[torch.Tensor], - rigidity_weight: float, - rigidity_masks: torch.Tensor, - max_dist: int, - min_rigidity_multiplier: float, -) -> torch.Tensor: - intra_loss = 0 - inter_loss = 0 - afields_cat = torch.cat(afields) - match_offsets_cat = torch.stack(match_offsets) - - match_offsets_warped = { - offset: afields_cat((match_offsets_cat == offset).float()) > 0.7 # type: ignore - for offset in range(1, max_dist + 1) - } - inter_loss = 0 - for offset in range(1, max_dist + 1): - inter_expectation = afields_cat[:-offset](pfields_raw[offset][offset:]) # type: ignore - inter_loss_map = inter_expectation - afields_cat[offset:] - - inter_loss_map_mask = match_offsets_warped[offset].squeeze()[offset:] - this_inter_loss = (inter_loss_map ** 2).sum(1)[..., inter_loss_map_mask].sum() - inter_loss += this_inter_loss - - intra_loss_map = get_rigidity_map_zcxy(afields_cat.pixels()) # type: ignore - with torch.no_grad(): - rigidity_masks_warped = afields_cat(rigidity_masks.float()) # type: ignore - rigidity_masks_warped[ - rigidity_masks_warped < min_rigidity_multiplier - ] = min_rigidity_multiplier - intra_loss = (intra_loss_map * rigidity_masks_warped.squeeze()).sum() - loss = inter_loss + rigidity_weight * intra_loss / ( - afields[0].shape[-1] * afields[0].shape[-1] / 4 - ) - - return loss # type: ignore - - -def _get_opt_range(fix: Literal["first", "last", "both"] | None, num_sections: int): - if fix is None: - opt_range = range(num_sections) - elif fix == "first": - opt_range = range(1, num_sections) - elif fix == "last": - opt_range = range(num_sections - 1) - else: - assert fix == "both" - opt_range = range(1, num_sections - 1) - return opt_range - - -@builder.register("perform_aced_relaxation") -def perform_aced_relaxation( # pylint: disable=too-many-branches - match_offsets: torch.Tensor, - pfields: dict[str, torch.Tensor], - rigidity_masks: torch.Tensor | None = None, - first_section_fix_field: torch.Tensor | None = None, - last_section_fix_field: torch.Tensor | None = None, - min_rigidity_multiplier: float = 0.0, - num_iter=100, - lr=0.3, - rigidity_weight=10.0, - fix: Optional[Literal["first", "last", "both"]] = "first", - max_dist: int = 2, - grad_clip: float | None = None, -) -> torch.Tensor: - assert "-1" in pfields - - if torch.cuda.is_available(): - device = "cuda" - else: - device = "cpu" - - max_displacement = max([field.abs().max().item() for field in pfields.values()]) - - if (match_offsets != 0).sum() == 0 or max_displacement < 0.01: - return torch.zeros_like(pfields["-1"]) - - match_offsets_zcxy = einops.rearrange(match_offsets, "C X Y Z -> Z C X Y").to(device) - - if rigidity_masks is not None: - rigidity_masks_zcxy = einops.rearrange(rigidity_masks, "C X Y Z -> Z C X Y").to(device) - else: - rigidity_masks_zcxy = torch.ones_like(match_offsets_zcxy) - - num_sections = match_offsets_zcxy.shape[0] - assert num_sections > 1, "Can't relax blocks with just one section" - - pfields_raw: Dict[int, torch.Tensor] = {} - - for offset_str, field in pfields.items(): - offset = -int(offset_str) - pfields_raw[offset] = ( - einops.rearrange(field, "C X Y Z -> Z C X Y") - .field() # type: ignore - .to(device) - .from_pixels() - ) - - if first_section_fix_field is not None: - assert fix in ["first", "both"] - - first_section_fix_field_zcxy = ( - einops.rearrange(first_section_fix_field, "C X Y Z -> Z C X Y") - .field() # type: ignore - .to(device) - .from_pixels() - ) - for offset in range(1, max_dist + 1): - pfields_raw[offset][offset] = first_section_fix_field_zcxy(pfields_raw[offset][offset]) - - if last_section_fix_field is not None: - assert fix in ["last", "both"] - - last_section_fix_field_inv = invert_field(last_section_fix_field.to(device)) - last_section_fix_field_inv_zcxy = ( - einops.rearrange(last_section_fix_field_inv, "C X Y Z -> Z C X Y") - .field() # type: ignore - .to(device) - .from_pixels() - ) - - for offset in range(1, max_dist + 1): - pfields_raw[offset][-1] = pfields_raw[offset][-1]( - last_section_fix_field_inv_zcxy # type: ignore - ) - - afields = [ - torch.zeros((1, 2, match_offsets_zcxy.shape[2], match_offsets_zcxy.shape[3])) - .to(device) - .field() # type: ignore - .from_pixels() - for _ in range(num_sections) - ] - - opt_range = _get_opt_range(fix=fix, num_sections=num_sections) - for i in opt_range: - afields[i].requires_grad = True - - optimizer = torch.optim.Adam( - [afields[i] for i in opt_range], - lr=lr, - ) - - with torchfields.set_identity_mapping_cache(True, clear_cache=True): - for i in range(num_iter): - loss_new = compute_aced_loss_new( - pfields_raw=pfields_raw, - afields=afields, - rigidity_masks=rigidity_masks_zcxy, - match_offsets=[match_offsets_zcxy[i] for i in range(num_sections)], - rigidity_weight=rigidity_weight, - max_dist=max_dist, - min_rigidity_multiplier=min_rigidity_multiplier, - ) - loss = loss_new - if i % 100 == 0: - logger.info(f"Iter {i} loss: {loss}") - optimizer.zero_grad() - loss.backward() - if grad_clip is not None: - torch.nn.utils.clip_grad_norm_(afields, max_norm=grad_clip) - optimizer.step() - - result_xy = torch.cat(afields, 0).pixels() # type: ignore - result = einops.rearrange(result_xy, "Z C X Y -> C X Y Z") - return result - - -def get_aced_match_offsets_naive( - non_tissue: torch.Tensor, - misalignment_mask_zm1: torch.Tensor, - misalignment_mask_zm2: Optional[torch.Tensor] = None, - misalignment_mask_zm3: Optional[torch.Tensor] = None, -) -> torch.Tensor: - - match_offsets = torch.ones_like(non_tissue, dtype=torch.int) * -1 - match_offsets[non_tissue] = 0 - - misalignment_mask_map = { - 1: misalignment_mask_zm1, - 2: misalignment_mask_zm2, - 3: misalignment_mask_zm3, - } - - for offset in sorted(misalignment_mask_map.keys()): - unmatched_locations = match_offsets == -1 - if unmatched_locations.sum() == 0: - break - if misalignment_mask_map[offset] is not None: - current_match_locations = misalignment_mask_map[offset] == 0 - match_offsets[unmatched_locations * current_match_locations] = offset - - match_offsets[match_offsets == -1] = 0 - result = match_offsets.byte() - return result - - -def get_aced_match_offsets( - tissue_mask: torch.Tensor, - misalignment_masks: dict[str, torch.Tensor], - pairwise_fields: dict[str, torch.Tensor], - pairwise_fields_inv: dict[str, torch.Tensor], - max_dist: int, -) -> dict[str, torch.Tensor]: - if torch.cuda.is_available(): - device = "cuda" - else: - device = "cpu" - - with torchfields.set_identity_mapping_cache(True, clear_cache=True): - tissue_mask_zcxy = einops.rearrange(tissue_mask, "1 X Y Z -> Z 1 X Y").to(device) - misalignment_masks_zcxy = { - k: einops.rearrange(v, "1 X Y Z -> Z 1 X Y").to(device) - for k, v in misalignment_masks.items() - } - pairwise_fields_zcxy = { - k: einops.rearrange(v, "C X Y Z -> Z C X Y") - .field() # type: ignore - .from_pixels() - .to(device) - for k, v in pairwise_fields.items() - } - pairwise_fields_inv_zcxy = { - k: einops.rearrange(v, "C X Y Z -> Z C X Y") - .field() # type: ignore - .from_pixels() - .to(device) - for k, v in pairwise_fields_inv.items() - } - - fwd_outcome = _perform_match_fwd_pass( - tissue_mask_zcxy=tissue_mask_zcxy, - misalignment_masks_zcxy=misalignment_masks_zcxy, - # pairwise_fields_zcxy=pairwise_fields_zcxy, - pairwise_fields_inv_zcxy=pairwise_fields_inv_zcxy, - max_dist=max_dist, - ) - sector_length_after_zcxy = _perform_match_bwd_pass( - match_offsets_inv_zcxy=fwd_outcome.match_offsets_inv_zcxy, - pairwise_fields_zcxy=pairwise_fields_zcxy, - max_dist=max_dist, - ) - img_mask_zcxy, aff_mask_zcxy = _get_masks( - sector_length_before_zcxy=fwd_outcome.sector_length_before_zcxy, - sector_length_after_zcxy=sector_length_after_zcxy, - # match_offsets_zcxy=fwd_outcome.match_offsets_zcxy, - # pairwise_fields_inv_zcxy=pairwise_fields_inv_zcxy, - max_dist=max_dist, - tissue_mask_zcxy=tissue_mask_zcxy, - ) - result = { - "match_offsets": fwd_outcome.match_offsets_zcxy, - "img_mask": img_mask_zcxy, - "aff_mask": aff_mask_zcxy, - "sector_length_after": sector_length_after_zcxy, - "sector_length_before": fwd_outcome.sector_length_before_zcxy, - } - result = {k: einops.rearrange(v, "Z C X Y -> C X Y Z") for k, v in result.items()} - return result - - -@attrs.mutable -class _FwdPassOutcome: - sector_length_before_zcxy: torch.Tensor - match_offsets_zcxy: torch.Tensor - match_offsets_inv_zcxy: torch.Tensor - - -def _perform_match_fwd_pass( - tissue_mask_zcxy: torch.Tensor, - misalignment_masks_zcxy: dict[str, torch.Tensor], - pairwise_fields_inv_zcxy: dict[str, torch.Tensor], - max_dist: int, -) -> _FwdPassOutcome: - num_sections = tissue_mask_zcxy.shape[0] - - sector_length_before_zcxy = torch.zeros_like(tissue_mask_zcxy).int() - match_offsets_zcxy = torch.zeros_like(tissue_mask_zcxy).int() - match_offsets_inv_zcxy = torch.zeros_like(tissue_mask_zcxy).int() - - for i in range(1, num_sections): - offset_scores = torch.zeros( - (max_dist, 1, tissue_mask_zcxy.shape[-2], tissue_mask_zcxy.shape[-1]), - dtype=torch.float32, - device=tissue_mask_zcxy.device, - ) - - offset_sector_lengths = torch.zeros( - (max_dist, 1, tissue_mask_zcxy.shape[-2], tissue_mask_zcxy.shape[-1]), - dtype=torch.int32, - device=tissue_mask_zcxy.device, - ) - - for offset in range(1, max_dist + 1): - j = i - offset - if j < 0: - break - - this_pairwise_field_inv = pairwise_fields_inv_zcxy[str(-offset)][i : i + 1] - - tgt_tissue_mask = this_pairwise_field_inv.sample( # type: ignore - tissue_mask_zcxy[j].float(), - mode="nearest", - ).int() - - this_tissue_mask = tissue_mask_zcxy[i] * tgt_tissue_mask - - this_misalignment_mask = misalignment_masks_zcxy[str(-offset)][i] - - offset_sector_lengths[offset - 1] = ( - this_pairwise_field_inv.sample( # type: ignore - sector_length_before_zcxy[j].float(), - mode="nearest", - ).int() - + 1 - ) - - # Offset lengths score prioritizes longer match chain - # misalignmened and non-tissue correspondences are not prioritized by it - offset_sector_lengths[offset - 1][this_tissue_mask == 0] = 0 - offset_sector_lengths[offset - 1][this_misalignment_mask != 0] = 0 - offset_sector_length_scores = offset_sector_lengths[offset - 1] / 1e5 - assert offset_sector_length_scores.max() <= 1.0 - - offset_scores[offset - 1] = this_tissue_mask * 100 - offset_scores[offset - 1] += (misalignment_masks_zcxy[str(-offset)][i] == 0) * 10 - offset_scores[offset - 1] += offset_sector_length_scores - offset_scores[offset - 1] += (max_dist - offset) / 1e7 - - chosen_offset_scores, chosen_offsets = offset_scores.max(0) - passable_choices = chosen_offset_scores >= 110 - match_offsets_zcxy[i][passable_choices] = chosen_offsets[passable_choices].int() + 1 - # match_offsets_zcxy[i] = this_tissue_mask - - # sector_length_before_zcxy[i] = offset_sector_lengths[chosen_offsets] - # TODO: how do vectorize this? - for choice in range(0, max_dist): - this_match_locations = chosen_offsets == choice - sector_length_before_zcxy[i][this_match_locations] = offset_sector_lengths[choice][ - this_match_locations - ] - - for offset in range(1, max_dist + 1): - j = i - offset - this_offset_matches = match_offsets_zcxy[i] == offset - - # Discard non-aligned matches for bwd pass - this_offset_matches[chosen_offset_scores < 110] = 0 - if this_offset_matches.sum() > 0: - this_inv_field = pairwise_fields_inv_zcxy[str(-offset)][i : i + 1] - this_offset_matches_inv = this_inv_field.sample( # type: ignore - this_offset_matches.float(), mode="nearest" - ).int() - this_offset_matches_inv[tissue_mask_zcxy[j] == 0] = 0 - match_offsets_inv_zcxy[j][this_offset_matches_inv != 0] = offset - - return _FwdPassOutcome( - sector_length_before_zcxy=sector_length_before_zcxy, - match_offsets_zcxy=match_offsets_zcxy, - match_offsets_inv_zcxy=match_offsets_inv_zcxy, - ) - - -def _get_masks( - sector_length_before_zcxy: torch.Tensor, - sector_length_after_zcxy: torch.Tensor, - # pairwise_fields_inv_zcxy: dict[str, torch.Tensor], - # match_offsets_zcxy: torch.Tensor, - tissue_mask_zcxy: torch.Tensor, - max_dist: int, -) -> tuple[torch.Tensor, torch.Tensor]: - # num_sections = sector_length_before_zcxy.shape[0] - - # img_mask_zcxy = (sector_length_before_zcxy + sector_length_after_zcxy) < max_dist - # aff_mask_zcxy = (sector_length_before_zcxy == 0) * (img_mask_zcxy == 0) - - img_mask_zcxy = (sector_length_before_zcxy + sector_length_after_zcxy) < max_dist - - aff_mask_zcxy = (sector_length_before_zcxy == 0) * (sector_length_after_zcxy >= max_dist) - # TODO: fix - # aff_mask_zcxy[1:] += (sector_length_after_zcxy[:-1] == 0) * ( - # sector_length_before_zcxy[:-1] >= max_dist - # ) - # TODO: Decide whether we want this - # for i in range(1, num_sections): - # for offset in range(1, max_dist + 1): - - # j = i - offset - # this_offset_matches = match_offsets_zcxy[i] == offset - - # if this_offset_matches.sum() > 0: - # this_inv_field = pairwise_fields_inv_zcxy[str(-offset)][i : i + 1] - # this_sector_length_after_from_j = this_inv_field.sample( # type: ignore - # sector_length_after_zcxy[j].float(), mode="nearest" - # ).int() - # this_sector_length_before_from_j = this_inv_field.sample( # type: ignore - # sector_length_before_zcxy[j].float(), mode="nearest" - # ).int() - - # back_connected_locations = sector_length_before_zcxy[i] == ( - # this_sector_length_before_from_j + 1 - # ) - # mid_connected_locations = sector_length_after_zcxy[i] == ( - # this_sector_length_after_from_j - 1 - # ) - # dangling_tail_locations = ( - # back_connected_locations * (mid_connected_locations == 0) * - # this_offset_matches - # ) - - # img_mask_zcxy[i][dangling_tail_locations] = True - # aff_mask_zcxy[i][dangling_tail_locations] = False - # if i + i < num_sections: - # aff_mask_zcxy[i + 1][dangling_tail_locations] = False - - img_mask_zcxy[0] = False - aff_mask_zcxy[0] = False - aff_mask_zcxy[-1][img_mask_zcxy[-1] != 0] = 1 - img_mask_zcxy[-1] = 0 - img_mask_zcxy[tissue_mask_zcxy == 0] = 1 - return img_mask_zcxy, aff_mask_zcxy - - -def _perform_match_bwd_pass( - match_offsets_inv_zcxy: torch.Tensor, - pairwise_fields_zcxy: dict[str, torch.Tensor], - max_dist: int, -): - sector_length_after_zcxy = torch.zeros_like(match_offsets_inv_zcxy) - num_sections = match_offsets_inv_zcxy.shape[0] - for i in range(num_sections - 1, -1, -1): - for offset in range(1, max_dist + 1): - j = i + offset - if j >= num_sections: - continue - - this_pairwise_field = pairwise_fields_zcxy[str(-offset)][j : j + 1] - - this_offset_sector_length = this_pairwise_field.sample( # type: ignore - sector_length_after_zcxy[j].float(), mode="nearest" - ).int() - this_offset_sector_length[match_offsets_inv_zcxy[i] != offset] = 0 - this_offset_sector_length[match_offsets_inv_zcxy[i] == offset] += 1 - - sector_length_after_zcxy[i] = torch.max( - sector_length_after_zcxy[i], this_offset_sector_length - ) - return sector_length_after_zcxy diff --git a/zetta_utils/alignment/base_coarsener.py b/zetta_utils/alignment/base_coarsener.py deleted file mode 100644 index 4dd495619..000000000 --- a/zetta_utils/alignment/base_coarsener.py +++ /dev/null @@ -1,82 +0,0 @@ -import attrs -import einops -import torch -from typeguard import typechecked - -from zetta_utils import builder, convnet - - -@builder.register("BaseCoarsener") -@typechecked -@attrs.mutable -class BaseCoarsener: - # Input int8 [ -127 .. 127] or uint8 [0 .. 255] - # Output int8 Encodings [-127 .. 127] - - # Don't create the model during initialization for efficient serialization - model_path: str - abs_val_thr: float = 0.005 - ds_factor: int = 1 - output_channels: int = 1 - tile_pad_in: int = 128 - tile_size: int = 1024 - - def __call__(self, src: torch.Tensor) -> torch.Tensor: - with torch.no_grad(): - device = "cuda" if torch.cuda.is_available() else "cpu" - - # load model during the call _with caching_ - model = convnet.utils.load_model(self.model_path, device=device, use_cache=True) - - # uint8 raw images or int8 encodings - if src.dtype == torch.int8: - data_in = src.float() / 127.0 # [-1.0 .. 1.0] - elif src.dtype == torch.uint8: - data_in = src.float() / 255.0 # [ 0.0 .. 1.0] - else: - raise ValueError(f"Unsupported src dtype: {src.dtype}") - - data_in = einops.rearrange(data_in, "C X Y Z -> Z C X Y").to(device) - result = torch.zeros( - data_in.shape[0], - self.output_channels, - data_in.shape[-2] // self.ds_factor, - data_in.shape[-1] // self.ds_factor, - dtype=torch.float32, - layout=data_in.layout, - device=data_in.device - ) - tile_pad_out = self.tile_pad_in // self.ds_factor - - for x in range(self.tile_pad_in, data_in.shape[-2] - self.tile_pad_in, self.tile_size): - x_start = x - self.tile_pad_in - x_end = x + self.tile_size + self.tile_pad_in - for y in range( - self.tile_pad_in, data_in.shape[-1] - self.tile_pad_in, self.tile_size - ): - y_start = y - self.tile_pad_in - y_end = y + self.tile_size + self.tile_pad_in - tile = data_in[:, :, x_start:x_end, y_start:y_end] - if (tile != 0).sum() > 0.0: - with torch.autocast(device_type=device): - tile_result = model(tile) - if tile_pad_out > 0: - tile_result = tile_result[ - :, :, tile_pad_out:-tile_pad_out, tile_pad_out:-tile_pad_out - ] - - result[ - :, - :, - x // self.ds_factor : x // self.ds_factor + tile_result.shape[-2], - y // self.ds_factor : y // self.ds_factor + tile_result.shape[-1], - ] = tile_result - - result = einops.rearrange(result, "Z C X Y -> C X Y Z") - - # Final layer assumed to be tanh - assert result.abs().max() <= 1 - result[result.abs() < self.abs_val_thr] = 0 - result = 127.0 * result - - return result.round().type(torch.int8).clamp(-127, 127) diff --git a/zetta_utils/alignment/base_encoder.py b/zetta_utils/alignment/base_encoder.py deleted file mode 100644 index 31231204d..000000000 --- a/zetta_utils/alignment/base_encoder.py +++ /dev/null @@ -1,55 +0,0 @@ -import attrs -import einops -import torch -from typeguard import typechecked - -from zetta_utils import builder, convnet - - -@builder.register("BaseEncoder") -@typechecked -@attrs.mutable -class BaseEncoder: - # Input uint8 [ 0 .. 255] - # Output int8 Encodings [-127 .. 127] - - # Don't create the model during initialization for efficient serialization - model_path: str - abs_val_thr: float = 0.005 - uint_output: bool = False - - def __call__(self, src: torch.Tensor) -> torch.Tensor: - if (src != 0).sum() == 0: - result = torch.zeros_like(src).float() - else: - if torch.cuda.is_available(): - device = "cuda" - else: - device = "cpu" - - # load model during the call _with caching_ - model = convnet.utils.load_model(self.model_path, device=device, use_cache=True) - - if src.dtype == torch.uint8: - data_in = src.float() / 255.0 # [0.0 .. 1.0] - else: - raise ValueError(f"Unsupported src dtype: {src.dtype}") - - data_in = einops.rearrange(data_in, "C X Y Z -> Z C X Y") - with torch.autocast(device_type=device): - result = model(data_in.to(device)) - result = einops.rearrange(result, "Z C X Y -> C X Y Z") - - # Final layer assumed to be tanh - assert result.abs().max() <= 1 - result[result.abs() < self.abs_val_thr] = 0 - if self.uint_output: - # FOR LEGACY MODELS. to be removed - result += 1 - result = 127.0 * result - - if self.uint_output: - # FOR LEGACY MODELS. to be removed - return result.type(torch.uint8) - else: - return result.round().type(torch.int8).clamp(-127, 127) diff --git a/zetta_utils/alignment/defect_detector.py b/zetta_utils/alignment/defect_detector.py deleted file mode 100644 index 675781db0..000000000 --- a/zetta_utils/alignment/defect_detector.py +++ /dev/null @@ -1,80 +0,0 @@ -import attrs -import einops -import torch -from typeguard import typechecked - -from zetta_utils import builder, convnet - - -@builder.register("DefectDetector") -@typechecked -@attrs.mutable -class DefectDetector: - # Input uint8 [ 0 .. 255] - # Output uint8 Prediction [0 .. 255] - - # Don't create the model during initialization for efficient serialization - model_path: str - tile_pad_in: int = 32 - tile_size: int = 448 - - def __call__(self, src: torch.Tensor) -> torch.Tensor: - if (src != 0).sum() == 0: - result = torch.zeros_like(src).float() - else: - if torch.cuda.is_available(): - device = "cuda" - else: - device = "cpu" - - # load model during the call _with caching_ - model = convnet.utils.load_model(self.model_path, device=device, use_cache=True) - - if src.dtype == torch.uint8: - data_in = src.float() / 255.0 # [0.0 .. 1.0] - else: - raise ValueError(f"Unsupported src dtype: {src.dtype}") - - data_in = einops.rearrange(data_in, "C X Y Z -> Z C X Y") - data_in = data_in.to(device=device) - with torch.no_grad(): - result = torch.zeros_like( - data_in[ - ..., - : data_in.shape[-2], - : data_in.shape[-1], - ] - ).float() - - tile_pad_out = self.tile_pad_in - - for x in range( - self.tile_pad_in, data_in.shape[-2] - self.tile_pad_in, self.tile_size - ): - x_start = x - self.tile_pad_in - x_end = x + self.tile_size + self.tile_pad_in - for y in range( - self.tile_pad_in, data_in.shape[-1] - self.tile_pad_in, self.tile_size - ): - y_start = y - self.tile_pad_in - y_end = y + self.tile_size + self.tile_pad_in - tile = data_in[:, :, x_start:x_end, y_start:y_end] - if (tile != 0).sum() > 0.0: - tile_result = model(tile) - if tile_pad_out > 0: - tile_result = tile_result[ - :, :, tile_pad_out:-tile_pad_out, tile_pad_out:-tile_pad_out - ] - - result[ - :, - :, - x : x + tile_result.shape[-2], - y : y + tile_result.shape[-1], - ] = tile_result - - result = einops.rearrange(result, "Z C X Y -> C X Y Z") - result = 255.0 * torch.sigmoid(result) - result[src == 0.0] = 0.0 - - return result.round().clamp(0, 255).type(torch.uint8) diff --git a/zetta_utils/alignment/encoding_coarsener.py b/zetta_utils/alignment/encoding_coarsener.py deleted file mode 100644 index 139a7bd5d..000000000 --- a/zetta_utils/alignment/encoding_coarsener.py +++ /dev/null @@ -1,41 +0,0 @@ -import attrs -import einops -import torch -from typeguard import typechecked - -from zetta_utils import builder, convnet - - -@builder.register("EncodingCoarsener") -@typechecked -@attrs.mutable -class EncodingCoarsener: - # Input int8 [ -127 .. 127] - # Output int8 Encodings [-127 .. 127] - - # Don't create the model during initialization for efficient serialization - model_path: str - abs_val_thr: float = 0.005 - - def __call__(self, src: torch.Tensor) -> torch.Tensor: - if torch.cuda.is_available(): - device = "cuda" - else: - device = "cpu" - - # load model during the call _with caching_ - model = convnet.utils.load_model(self.model_path, device=device, use_cache=True) - if src.dtype == torch.int8: - data_in = src.float() / 127.0 - else: - raise ValueError(f"Unsupported src dtype: {src.dtype}") - - data_in = einops.rearrange(data_in, "C X Y Z -> Z C X Y") - result = model(data_in.to(device)) - result = einops.rearrange(result, "Z C X Y -> C X Y Z") - - # Final layer assumed to be tanh - assert result.abs().max() <= 1 - result[result.abs() < self.abs_val_thr] = 0 - result = 127.0 * (result) - return result.round().type(torch.int8).clamp(-127, 127) diff --git a/zetta_utils/alignment/field.py b/zetta_utils/alignment/field.py deleted file mode 100644 index d3373cbec..000000000 --- a/zetta_utils/alignment/field.py +++ /dev/null @@ -1,199 +0,0 @@ -from typing import Literal, Tuple - -import einops -import torch -import torchfields # pylint: disable=unused-import -from torch.optim.lr_scheduler import ReduceLROnPlateau - -from zetta_utils import builder -from zetta_utils.augmentations.tensor import rand_perlin_2d_octaves - - -def profile_field2d_percentile( - field: torch.Tensor, # C, X, Y, Z - high: float = 25, - low: float = 75, -) -> Tuple[int, int]: - - nonzero_field_mask = (field[0] != 0) & (field[1] != 0) - - nonzero_field = field[..., nonzero_field_mask].squeeze() - - if nonzero_field.sum() == 0 or len(nonzero_field.shape) == 1: - result = (0, 0) - else: - low_l = percentile(nonzero_field, low) - high_l = percentile(nonzero_field, high) - mid = 0.5 * (low_l + high_l) - result = (int(mid[0]), int(mid[1])) - - return result - - -def percentile(field: torch.Tensor, q: float): - # https://gist.github.com/spezold/42a451682422beb42bc43ad0c0967a30 - """ - Return the ``q``-th percentile of the flattened input tensor's data. - CAUTION: - * Needs PyTorch >= 1.1.0, as ``torch.kthvalue()`` is used. - * Values are not interpolated, which corresponds to - ``numpy.percentile(..., interpolation="nearest")``. - :param field: Input tensor. - :param q: Percentile to compute, which must be between 0 and 100 inclusive. - :return: Resulting value (scalar). - """ - # Note that ``kthvalue()`` works one-based, i.e. the first sorted value - # indeed corresponds to k=1, not k=0! Use float(q) instead of q directly, - # so that ``round()`` returns an integer, even if q is a np.float32. - k = 1 + round(0.01 * float(q) * (field.shape[1] - 1)) - result = field.kthvalue(k, dim=1).values - return result - - -@builder.register("invert_field") -def invert_field(src: torch.Tensor, mode: Literal["opti", "torchfields"] = "opti") -> torch.Tensor: - if src.abs().sum() == 0: - return src - - if mode == "opti": - result = invert_field_opti(src) - else: - src_zcxy = einops.rearrange(src, "C X Y Z -> Z C X Y").cuda().field_() # type: ignore - with torchfields.set_identity_mapping_cache(True): - result_zcxy = (~(src_zcxy.from_pixels())).pixels() - result = einops.rearrange(result_zcxy, "Z C X Y -> C X Y Z") - return result - - -def invert_field_opti(src: torch.Tensor, num_iter: int = 200, lr: float = 1e-3) -> torch.Tensor: - if src.abs().sum() == 0: - return src - - src_zcxy = ( - einops.rearrange(src, "C X Y Z -> Z C X Y").cuda().field_().from_pixels() # type: ignore - ) - - with torchfields.set_identity_mapping_cache(True): - inverse_zcxy = (-src_zcxy).clone() - inverse_zcxy.requires_grad = True - - optimizer = torch.optim.Adam([inverse_zcxy], lr=lr) - scheduler = ReduceLROnPlateau(optimizer, "min", patience=5, min_lr=1e-5) - for _ in range(num_iter): - loss = inverse_zcxy(src_zcxy).pixels().abs().sum() - optimizer.zero_grad() - loss.backward() - optimizer.step() - scheduler.step(loss) - - return einops.rearrange(inverse_zcxy.detach().pixels(), "Z C X Y -> C X Y Z") - - -@builder.register("gen_biased_perlin_noise_field") -def gen_biased_perlin_noise_field( - shape, - *, - res, - octaves=1, - persistence=0.5, - field_magn_thr_px=1.0, - max_displacement_px=None, - device="cuda", -) -> torch.Tensor: - """Generates a perlin noise vector field with the provided median and maximum vector length.""" - eps = 1e-7 - perlin = rand_perlin_2d_octaves(shape, res, octaves, persistence, device=device) - warp_field = einops.rearrange(perlin, "C X Y Z -> Z C X Y").field_() # type: ignore - - vec_length = warp_field.norm(dim=1, keepdim=True).tensor_() - vec_length_median = torch.median(vec_length) - vec_length_centered = vec_length - vec_length_median - - vec_length_target = torch.where( - vec_length_centered < 0, - vec_length_centered * field_magn_thr_px / abs(vec_length_centered.min()) - + field_magn_thr_px, - vec_length_centered - * (max_displacement_px - field_magn_thr_px) - / abs(vec_length_centered.max()) - + field_magn_thr_px, - ) - - warp_field *= vec_length_target / (vec_length + eps) - return einops.rearrange(warp_field, "Z C X Y -> C X Y Z").tensor_() - - -def get_rigidity_map_zcxy( - field: torch.Tensor, power: float = 2, diagonal_mult: float = 1.0 -) -> torch.Tensor: - # Kernel on Displacement field yields change of displacement - - if field.abs().sum() == 0: - return torch.zeros((field.shape[0], field.shape[2], field.shape[3]), device=field.device) - - batch = field.shape[0] - diff_ker = torch.tensor( - [ - [ - [[0, 0, 0], [-1, 1, 0], [0, 0, 0]], - [[0, -1, 0], [0, 1, 0], [0, 0, 0]], - [[-1, 0, 0], [0, 1, 0], [0, 0, 0]], - [[0, 0, -1], [0, 1, 0], [0, 0, 0]], - ] - ], - dtype=field.dtype, - device=field.device, - ) - - diff_ker = diff_ker.permute(1, 0, 2, 3).repeat(2, 1, 1, 1) - - # Add distance between pixel to get absolute displacement - diff_bias = torch.tensor( - [1.0, 0.0, 1.0, -1.0, 0.0, 1.0, 1.0, 1.0], - dtype=field.dtype, - device=field.device, - ) - delta = torch.conv2d(field, diff_ker, diff_bias, groups=2, padding=[2, 2]) - # delta1 = delta.reshape(2, 4, *delta.shape[-2:]).permute(1, 2, 3, 0) # original - delta = delta.reshape(batch, 2, 4, *delta.shape[-2:]).permute(0, 2, 3, 4, 1) - - # spring_lengths1 = torch.norm(delta1, dim=3) - spring_lengths = torch.norm(delta, dim=-1) - - spring_defs = torch.stack( - [ - spring_lengths[:, 0, 1:-1, 1:-1] - 1, - spring_lengths[:, 0, 1:-1, 2:] - 1, - spring_lengths[:, 1, 1:-1, 1:-1] - 1, - spring_lengths[:, 1, 2:, 1:-1] - 1, - (spring_lengths[:, 2, 1:-1, 1:-1] - 2 ** (1 / 2)) * (diagonal_mult) ** (1 / power), - (spring_lengths[:, 2, 2:, 2:] - 2 ** (1 / 2)) * (diagonal_mult) ** (1 / power), - (spring_lengths[:, 3, 1:-1, 1:-1] - 2 ** (1 / 2)) * (diagonal_mult) ** (1 / power), - (spring_lengths[:, 3, 2:, 0:-2] - 2 ** (1 / 2)) * (diagonal_mult) ** (1 / power), - ] - ) - # Slightly faster than sum() + pow(), and no need for abs() if power is odd - result = torch.norm(spring_defs, p=power, dim=0).pow(power) - - total = 4 + 4 * diagonal_mult - - result /= total - - # Remove incorrect smoothness values caused by 2px zero padding - result[..., 0:2, :] = 0 - result[..., -2:, :] = 0 - result[..., :, 0:2] = 0 - result[..., :, -2:] = 0 - return result - - -@builder.register("get_rigidity_map") -def get_rigidity_map( - field: torch.Tensor, power: float = 2, diagonal_mult: float = 1.0 -) -> torch.Tensor: - field_zcxy = einops.rearrange(field, "C X Y Z -> Z C X Y") - result_zcxy = get_rigidity_map_zcxy( - field_zcxy, power=power, diagonal_mult=diagonal_mult - ).unsqueeze(1) - result = einops.rearrange(result_zcxy, "Z C X Y -> C X Y Z") - return result diff --git a/zetta_utils/alignment/misalignment_detector.py b/zetta_utils/alignment/misalignment_detector.py deleted file mode 100644 index 7a36b4e01..000000000 --- a/zetta_utils/alignment/misalignment_detector.py +++ /dev/null @@ -1,54 +0,0 @@ -import attrs -import einops -import torch -from typeguard import typechecked - -from zetta_utils import builder, convnet - - -@builder.register("naive_misd") -def naive_misd(src: torch.Tensor, tgt: torch.Tensor) -> torch.Tensor: - result = ((src == 0) + (tgt == 0)).byte() - return result - - -@builder.register("MisalignmentDetector") -@typechecked -@attrs.mutable -class MisalignmentDetector: - # Don't create the model during initialization for efficient serialization - model_path: str - - def __call__(self, src: torch.Tensor, tgt: torch.Tensor) -> torch.Tensor: - if (src != 0).sum() == 0 or (tgt != 0).sum() == 0: - return torch.zeros_like(src[:1], dtype=torch.uint8) - - if torch.cuda.is_available(): - device = "cuda" - else: - device = "cpu" - - # load model during the call _with caching_ - model = convnet.utils.load_model(self.model_path, device=device, use_cache=True) - - assert src.dtype == tgt.dtype - src_zcxy = einops.rearrange(src, "C X Y Z -> Z C X Y").float() - tgt_zcxy = einops.rearrange(tgt, "C X Y Z -> Z C X Y").float() - - if src.dtype == torch.uint8: - data_in = torch.cat((src_zcxy, tgt_zcxy), 1) / 255.0 - elif src.dtype == torch.int8: - data_in = torch.cat((src_zcxy, tgt_zcxy), 1) / 127.0 - - with torch.no_grad(): - result = model(data_in.to(device)) - - result = einops.rearrange(result, "Z C X Y -> C X Y Z") - - assert result.shape[0] == 1 - - assert result.max() <= 1, "Final layer of misalignment detector assumed to be sigmoid" - assert result.min() >= 0, "Final layer of misalignment detector assumed to be sigmoid" - result = 255.0 * result - - return result.round().clamp(0.0, 255.0).byte().to(src.device) diff --git a/zetta_utils/alignment/online_finetuner.py b/zetta_utils/alignment/online_finetuner.py deleted file mode 100644 index ebbd48cdc..000000000 --- a/zetta_utils/alignment/online_finetuner.py +++ /dev/null @@ -1,110 +0,0 @@ -import einops -import metroem -import torch -import torchfields - -from zetta_utils import builder, log, tensor_ops - -logger = log.get_logger("zetta_utils") - - -@builder.register("align_with_online_finetuner") -def align_with_online_finetuner( - src, # (C, X, Y, Z) - tgt, # (C, X, Y, Z) - sm=100, - num_iter=200, - lr=5e-2, - src_field=None, # (C, X, Y, Z) -): - assert src.shape == tgt.shape - # assert len(src.shape) == 4 # (1, C, X, Y,) - # assert src.shape[0] == 1 - src = einops.rearrange(src, "C X Y 1 -> 1 C X Y").float() - tgt = einops.rearrange(tgt, "C X Y 1 -> 1 C X Y").float() - - if src_field is None: - src_field = torch.zeros([1, 2, tgt.shape[2], tgt.shape[3]], device=tgt.device).float() - else: - src_field = einops.rearrange(src_field, "C X Y 1 -> 1 C X Y") - scales = [src.shape[i] / src_field.shape[i] for i in range(2, 4)] - assert scales[0] == scales[1] - src_field = tensor_ops.interpolate( - src_field, scale_factor=scales, mode="field", unsqueeze_input_to=4 - ) - - orig_device = src.device - - if torch.cuda.is_available(): - src = src.cuda() - tgt = tgt.cuda() - src_field = src_field.cuda() - - if src.abs().sum() == 0 or tgt.abs().sum() == 0: - result = torch.zeros_like(src_field) - elif num_iter <= 0: - result = src_field - else: - sm_keys = { - "src": [ - { - "name": "src_zeros", - "fm": 0, - "mask_value": 0.001, - "binarization": {"strat": "eq", "value": 0}, - } - ], - "tgt": [ - { - "name": "tgt_zeros", - "fm": 0, - "mask_value": 0.001, - "binarization": {"strat": "eq", "value": 0}, - } - ], - } - mse_keys = { - "src": [ - { - "name": "src_zeros", - "fm": 0, - "mask_value": 0, - "binarization": {"strat": "eq", "value": 0}, - } - ], - "tgt": [ - { - "name": "tgt_zeros", - "fm": 0, - "mask_value": 0, - "binarization": {"strat": "eq", "value": 0}, - } - ], - } - - with torchfields.set_identity_mapping_cache(True, clear_cache=True): - result = metroem.finetuner.optimize_pre_post_ups( - src, - tgt, - src_field, - src_zeros=(src[:, 0] == 0.0).unsqueeze(1), - tgt_zeros=(tgt[:, 0] == 0.0).unsqueeze(1), - src_defects=torch.zeros((src.shape[0], 1, src.shape[1], src.shape[2])), - tgt_defects=torch.zeros((src.shape[0], 1, src.shape[1], src.shape[2])), - crop=2, - num_iter=num_iter, - lr=lr, - sm=sm, - l2=0, - wd=0, - max_bad=5, - verbose=True, - opt_res_coarsness=0, - normalize=True, - sm_keys_to_apply=sm_keys, - mse_keys_to_apply=mse_keys, - ) - result = einops.rearrange(result, "1 C X Y -> C X Y 1") - result = result.detach().to(orig_device) - result[result.abs() < 0.001] = 0 - return result diff --git a/zetta_utils/alignment/resin_detector.py b/zetta_utils/alignment/resin_detector.py deleted file mode 100644 index a16a0f979..000000000 --- a/zetta_utils/alignment/resin_detector.py +++ /dev/null @@ -1,114 +0,0 @@ -import attrs -import cc3d -import cv2 -import einops -import fastremap -import numpy as np -import torch -from typeguard import typechecked - -from zetta_utils import builder, convnet - - -@builder.register("ResinDetector") -@typechecked -@attrs.mutable -class ResinDetector: - # Input uint8 [ 0 .. 255] - # Output uint8 Prediction [0 .. 255] - - # Don't create the model during initialization for efficient serialization - model_path: str - tile_pad_in: int = 32 - tile_size: int = 448 - tissue_filter_threshold: int = 1000 - resin_filter_threshold: int = 1000 - - def __call__(self, src: torch.Tensor) -> torch.Tensor: - if (src != 0).sum() == 0: - return torch.full_like(src, 255).type(torch.uint8) - else: - if torch.cuda.is_available(): - device = "cuda" - else: - device = "cpu" - - # load model during the call _with caching_ - model = convnet.utils.load_model(self.model_path, device=device, use_cache=True) - - if src.dtype == torch.uint8: - data_in = src.float() / 255.0 # [0.0 .. 1.0] - else: - raise ValueError(f"Unsupported src dtype: {src.dtype}") - - data_in = einops.rearrange(data_in, "C X Y Z -> Z C X Y") - data_in = data_in.to(device=device) - with torch.no_grad(): - result = torch.zeros_like( - data_in[ - ..., - : data_in.shape[-2], - : data_in.shape[-1], - ] - ).float() - - tile_pad_out = self.tile_pad_in - - for x in range( - self.tile_pad_in, data_in.shape[-2] - self.tile_pad_in, self.tile_size - ): - x_start = x - self.tile_pad_in - x_end = x + self.tile_size + self.tile_pad_in - for y in range( - self.tile_pad_in, data_in.shape[-1] - self.tile_pad_in, self.tile_size - ): - y_start = y - self.tile_pad_in - y_end = y + self.tile_size + self.tile_pad_in - tile = data_in[:, :, x_start:x_end, y_start:y_end] - if (tile != 0).sum() > 0.0: - tile_result = model(tile) - if tile_pad_out > 0: - tile_result = tile_result[ - :, :, tile_pad_out:-tile_pad_out, tile_pad_out:-tile_pad_out - ] - - result[ - :, - :, - x : x + tile_result.shape[-2], - y : y + tile_result.shape[-1], - ] = tile_result - - result = einops.rearrange(result, "Z C X Y -> C X Y Z") - result = torch.sigmoid(result) - pred = ((result > 250.0 / 255.0) * 255).to(dtype=torch.uint8, device="cpu") - - # Background is resin - pred[src == 0.0] = 255 - - # Filter small islands of tissue - tissue = (255 - pred).squeeze().numpy() - tissue = cv2.morphologyEx(tissue, cv2.MORPH_CLOSE, np.ones((3, 3), np.uint8)) - tissue = cv2.morphologyEx(tissue, cv2.MORPH_OPEN, np.ones((3, 3), np.uint8)) - if self.tissue_filter_threshold > 0: - # TODO: refactor logic with the below & test - islands = cc3d.connected_components(tissue) - uniq, counts = fastremap.unique(islands, return_counts=True) - islands = fastremap.mask( - islands, - [lbl for lbl, cnt in zip(uniq, counts) if cnt < self.tissue_filter_threshold], - ) - tissue[islands == 0] = 0 - - # Filter small islands of resin - resin = 255 - tissue - if self.resin_filter_threshold > 0: - islands = cc3d.connected_components(resin) - uniq, counts = fastremap.unique(islands, return_counts=True) - islands = fastremap.mask( - islands, - [lbl for lbl, cnt in zip(uniq, counts) if cnt < self.resin_filter_threshold], - ) - resin[islands == 0] = 0 - - return torch.from_numpy(resin).reshape(pred.shape)