Skip to content

Commit

Permalink
Merge pull request #34 from Vance-Raiti/main
Browse files Browse the repository at this point in the history
Remove depreciated torchmetrics __init__ argument and fix str_to_one_hot for CUDA devices
  • Loading branch information
lucidrains authored Nov 3, 2023
2 parents 243151b + a46cf5a commit e31cdb4
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 4 deletions.
5 changes: 2 additions & 3 deletions enformer_pytorch/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,10 @@

class MeanPearsonCorrCoefPerChannel(Metric):
is_differentiable: Optional[bool] = False
full_state_update:bool = False
higher_is_better: Optional[bool] = True
def __init__(self, n_channels:int, dist_sync_on_step=False):
"""Calculates the mean pearson correlation across channels aggregated over regions"""
super().__init__(dist_sync_on_step=dist_sync_on_step, full_state_update=False)
super().__init__(dist_sync_on_step=dist_sync_on_step)
self.reduce_dims=(0, 1)
self.add_state("product", default=torch.zeros(n_channels, dtype=torch.float32), dist_reduce_fx="sum", )
self.add_state("true", default=torch.zeros(n_channels, dtype=torch.float32), dist_reduce_fx="sum", )
Expand Down Expand Up @@ -41,4 +40,4 @@ def compute(self):
pred_var = self.pred_squared - self.count * torch.square(pred_mean)
tp_var = torch.sqrt(true_var) * torch.sqrt(pred_var)
correlation = covariance / tp_var
return correlation
return correlation
3 changes: 2 additions & 1 deletion enformer_pytorch/modeling_enformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -446,8 +446,9 @@ def forward(
if isinstance(x, list):
x = str_to_one_hot(x)

elif x.dtype == torch.long:
elif type(x) == torch.Tensor and x.dtype == torch.long:
x = seq_indices_to_one_hot(x)
x.to(self.device)

no_batch = x.ndim == 2

Expand Down

0 comments on commit e31cdb4

Please sign in to comment.