diff --git a/enformer_pytorch/metrics.py b/enformer_pytorch/metrics.py index 427f2a6..a75378b 100644 --- a/enformer_pytorch/metrics.py +++ b/enformer_pytorch/metrics.py @@ -5,11 +5,10 @@ class MeanPearsonCorrCoefPerChannel(Metric): is_differentiable: Optional[bool] = False - full_state_update:bool = False higher_is_better: Optional[bool] = True def __init__(self, n_channels:int, dist_sync_on_step=False): """Calculates the mean pearson correlation across channels aggregated over regions""" - super().__init__(dist_sync_on_step=dist_sync_on_step, full_state_update=False) + super().__init__(dist_sync_on_step=dist_sync_on_step) self.reduce_dims=(0, 1) self.add_state("product", default=torch.zeros(n_channels, dtype=torch.float32), dist_reduce_fx="sum", ) self.add_state("true", default=torch.zeros(n_channels, dtype=torch.float32), dist_reduce_fx="sum", ) @@ -41,4 +40,4 @@ def compute(self): pred_var = self.pred_squared - self.count * torch.square(pred_mean) tp_var = torch.sqrt(true_var) * torch.sqrt(pred_var) correlation = covariance / tp_var - return correlation \ No newline at end of file + return correlation diff --git a/enformer_pytorch/modeling_enformer.py b/enformer_pytorch/modeling_enformer.py index 8420873..7f98ea5 100644 --- a/enformer_pytorch/modeling_enformer.py +++ b/enformer_pytorch/modeling_enformer.py @@ -408,8 +408,9 @@ def forward( if isinstance(x, list): x = str_to_one_hot(x) - elif x.dtype == torch.long: + elif type(x) == torch.Tensor and x.dtype == torch.long: x = seq_indices_to_one_hot(x) + x.to(self.device) no_batch = x.ndim == 2