diff --git a/bactgraph/modeling/model.py b/bactgraph/modeling/model.py index 8c038b7..4ef0ad2 100644 --- a/bactgraph/modeling/model.py +++ b/bactgraph/modeling/model.py @@ -133,7 +133,7 @@ def forward(self, x: torch.Tensor, edge_index: torch.Tensor): def training_step(self, batch, batch_idx): """Training step.""" x_batch, edge_index_batch, y = batch - x, edge_index = batch_into_single_graph(x_batch, edge_index_batch) + x, edge_index, _ = batch_into_single_graph(x_batch, edge_index_batch) print(x.shape, edge_index.shape, y.shape) preds = self.forward(x, edge_index) # Squeeze if your output_dim=1 to match target shape @@ -144,7 +144,7 @@ def training_step(self, batch, batch_idx): def validation_step(self, batch, batch_idx): """Validation step""" x_batch, edge_index_batch, y = batch - x, edge_index = batch_into_single_graph(x_batch, edge_index_batch) + x, edge_index, _ = batch_into_single_graph(x_batch, edge_index_batch) print(x.shape, edge_index.shape, y.shape) preds = self.forward(x, edge_index) @@ -160,7 +160,7 @@ def validation_step(self, batch, batch_idx): def test_step(self, batch, batch_idx) -> dict: """Test step.""" x_batch, edge_index_batch, y = batch - x, edge_index = batch_into_single_graph(x_batch, edge_index_batch) + x, edge_index, _ = batch_into_single_graph(x_batch, edge_index_batch) preds = self.forward(x, edge_index) loss = F.mse_loss(preds.squeeze(), y.view(-1))