Skip to content

Commit

Permalink
fix message passing (geoelements#95)
Browse files Browse the repository at this point in the history
  • Loading branch information
chhsiao93 authored and yjchoi1 committed Jan 8, 2025
1 parent 9f38928 commit 1e206cd
Show file tree
Hide file tree
Showing 2 changed files with 92 additions and 3 deletions.
17 changes: 14 additions & 3 deletions gns/graph_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,13 @@ def forward(
# Start propagating messages.
# Takes in the edge indices and all additional data which is needed to
# construct messages and to update node embeddings.
# Call PyG propagate() method:
# 1. Message phase - compute messages for each edge
# 2. Aggregate phase - aggregate messages for each node
# 3. Update phase - updates only the node features
# Update uses the message from step 1 and any original arguments passed to
# propagate() to update the node embeddings. This is why we need to store
# the updated edge features to return them from the update() method.
x, edge_features = self.propagate(
edge_index=edge_index, x=x, edge_features=edge_features
)
Expand All @@ -212,8 +219,8 @@ def message(
"""
# Concat edge features with a final shape of [nedges, latent_dim*3]
edge_features = torch.cat([x_i, x_j, edge_features], dim=-1)
edge_features = self.edge_fn(edge_features)
return edge_features
self._edge_features = self.edge_fn(edge_features) # Create and store
return self._edge_features # This gets passed to aggregate()

def update(
self, x_updated: torch.tensor, x: torch.tensor, edge_features: torch.tensor
Expand All @@ -233,9 +240,13 @@ def update(
"""
# Concat node features with a final shape of
# [nparticles, latent_dim (or nnode_in) *2]
# This gets called later, after message() and aggregate()
# Update modified from MessagePassing takes the output of aggregation
# as first argument and any argument which was initially passed to
# propagate hence we need to return the stored value of edge_features
x_updated = torch.cat([x_updated, x], dim=-1)
x_updated = self.node_fn(x_updated)
return x_updated, edge_features
return x_updated, self._edge_features


class Processor(MessagePassing):
Expand Down
78 changes: 78 additions & 0 deletions test/test_message_edge_features.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
from gns.graph_network import *
import torch
from torch_geometric.data import Data
import pytest


@pytest.fixture
def interaction_network_data():
    """Build a tiny two-node InteractionNetwork plus matching graph tensors.

    Returns:
        Tuple of (model, node_features, edge_index, edge_features) where the
        graph has two nodes connected in both directions and every feature
        dimension is 2.
    """
    # Two nodes, linked by a pair of directed edges (0 -> 1 and 1 -> 0).
    connectivity = torch.tensor([[0, 1], [1, 0]], dtype=torch.long)
    node_features = torch.tensor([[1.0, 2.0], [3.0, 4.0]])  # node features
    edge_features = torch.tensor([[1.0, 1.0], [2.0, 2.0]])  # edge features

    network = InteractionNetwork(
        nnode_in=2,
        nnode_out=2,
        nedge_in=2,
        nedge_out=2,
        nmlp_layers=2,
        mlp_hidden_dim=2,
    )
    return network, node_features, connectivity, edge_features


def test_edge_update(interaction_network_data):
    """Test if edge features are updated and finite and are not simply doubled.

    Regression test for message passing: one forward pass through the
    InteractionNetwork must return edge features that (a) keep the input
    shape, (b) genuinely differ from the input (not identical, not a trivial
    doubling), and (c) contain only finite values.
    """
    model, x, edge_index, edge_attr = interaction_network_data
    old_edge_attr = edge_attr.clone()  # Save the old edge features

    # One message passing step
    _, updated_edge_attr = model(x=x, edge_index=edge_index, edge_features=edge_attr)

    # Check if edge features shape is correct.
    # BUG FIX: the previous assertion compared the input with its own clone,
    # which is always true — the model *output* shape is what must be checked.
    assert (
        updated_edge_attr.shape == old_edge_attr.shape
    ), f"Edge features shape is not preserved, changed from {old_edge_attr.shape} to {updated_edge_attr.shape}"
    # Check if edge features are updated
    assert not torch.equal(
        updated_edge_attr, old_edge_attr * 2
    ), "Edge features are simply doubled"
    assert not torch.equal(
        updated_edge_attr, old_edge_attr
    ), "Edge features are not updated"
    # Check if edge features are finite.
    # BUG FIX: previously this checked the untouched input tensor; the
    # finiteness guarantee must hold for the model output.
    assert torch.all(
        torch.isfinite(updated_edge_attr)
    ), "Edge features are not finite"


def test_gradients_computed(interaction_network_data):
    """Test if gradients are computed and finite.

    Runs two chained message-passing steps (each step's edge output feeds the
    next, node features accumulate residually), backpropagates a scalar loss
    from the final edge features, and checks that gradients reach both leaf
    tensors and contain no NaN/Inf entries.
    """
    model, node_features, edge_index, edge_attr = interaction_network_data
    node_features.requires_grad = True
    edge_attr.requires_grad = True

    # Two message-passing steps with residual node updates.
    nodes, edges = node_features, edge_attr
    for _ in range(2):
        aggregated, edges = model(x=nodes, edge_index=edge_index, edge_features=edges)
        nodes = nodes + aggregated

    # Scalar loss built from the final edge features.
    loss = edges.sum()
    loss.backward()

    # Gradients must have been populated on the leaf tensors...
    assert node_features.grad is not None, "Gradients for node features are not computed"
    assert edge_attr.grad is not None, "Gradients for edge features are not computed"
    # ...and every gradient entry must be finite.
    assert torch.all(
        torch.isfinite(node_features.grad)
    ), "Gradients for node features are not finite"
    assert torch.all(
        torch.isfinite(edge_attr.grad)
    ), "Gradients for edge features are not finite"

0 comments on commit 1e206cd

Please sign in to comment.