Skip to content

Commit

Permalink
- add support for removing hu moments
Browse files Browse the repository at this point in the history
- add config example
  • Loading branch information
shaikh58 committed Nov 27, 2024
1 parent 0668076 commit dc5a6f0
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 4 deletions.
8 changes: 5 additions & 3 deletions dreem/models/visual_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,8 +167,9 @@ def forward(self, img: torch.Tensor) -> torch.Tensor:
class DescriptorVisualEncoder(torch.nn.Module):
"""Visual Encoder based on image descriptors"""

def __init__(self, **kwargs):
def __init__(self, use_hu_moments: bool = False, **kwargs):
super().__init__()
self.use_hu_moments = use_hu_moments

def compute_hu_moments(self, img):
mu = measure.moments_central(img)
Expand All @@ -194,7 +195,8 @@ def forward(self, img: torch.Tensor) -> torch.Tensor:

inertia_tensor = self.compute_inertia_tensor(im)
mean_intensity = im.mean()
# hu_moments = self.compute_hu_moments(im)
if self.use_hu_moments:
hu_moments = self.compute_hu_moments(im)

# Flatten inertia tensor
inertia_tensor_flat = inertia_tensor.flatten()
Expand All @@ -204,7 +206,7 @@ def forward(self, img: torch.Tensor) -> torch.Tensor:
[
inertia_tensor_flat,
[mean_intensity],
# hu_moments
hu_moments if self.use_hu_moments else [],
]
)

Expand Down
2 changes: 1 addition & 1 deletion dreem/training/configs/base.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ model:
# backend: "timm"
# pretrained: false
# descriptor:
# source: "skimage"
# use_hu_moments: false
d_model: 1024
nhead: 8
num_encoder_layers: 1
Expand Down

0 comments on commit dc5a6f0

Please sign in to comment.