Add filter_all argument to BaseData.snapshot #9966

Open
wants to merge 1 commit into
base: master

nelsonaloysio
Contributor

Adds a filter_all argument to BaseData.snapshot, which makes it possible to return a snapshot with both node- and edge-level data filtered.

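Currently, snapshot filters either node-level or edge-level data, depending on whether attr is a node-level or an edge-level attribute. With this PR, passing filter_all=True calls _select once more on the object, so the other level is filtered as well: after a node-level filter, the edges and their attributes are filtered, and after an edge-level filter, the nodes and their attributes are filtered.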
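To make the two-pass idea concrete, here is a minimal sketch in plain Python (lists instead of tensors, so it runs without PyTorch). The helper names filter_nodes_then_edges and filter_edges_then_nodes are hypothetical and not part of PyG. The rules used for the second pass are assumptions for illustration only: an edge is kept when both of its endpoints survive, and a node is kept when at least one surviving edge touches it. They may not match what _select does in this PR.

```python
def filter_nodes_then_edges(num_nodes, edges, keep_node):
    """First pass keeps flagged nodes; second pass filters the edges.

    ASSUMPTION: an edge survives only if both endpoints survive.
    """
    # Map each surviving node to its new, compacted index.
    new_index = {}
    for old in range(num_nodes):
        if keep_node[old]:
            new_index[old] = len(new_index)
    kept_edges = [
        (new_index[u], new_index[v])
        for u, v in edges
        if u in new_index and v in new_index
    ]
    return len(new_index), kept_edges


def filter_edges_then_nodes(edges, keep_edge):
    """First pass keeps flagged edges; second pass filters the nodes.

    ASSUMPTION: a node survives only if a surviving edge touches it.
    """
    kept = [edge for edge, keep in zip(edges, keep_edge) if keep]
    touched = sorted({node for edge in kept for node in edge})
    new_index = {old: new for new, old in enumerate(touched)}
    relabeled = [(new_index[u], new_index[v]) for u, v in kept]
    return len(touched), relabeled


# Toy graph with 4 nodes and 4 directed edges.
toy_edges = [(0, 1), (1, 2), (2, 3), (3, 1)]
node_result = filter_nodes_then_edges(4, toy_edges, [False, True, False, True])
edge_result = filter_edges_then_nodes(toy_edges, [False, True, False, True])
print(node_result)
print(edge_result)
```

In both helpers the surviving nodes are relabeled to a compact range so the edge list stays valid; whether the PR relabels indices is likewise not stated here.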

Example

Code:

import torch
from torch_geometric.datasets import Planetoid

data = Planetoid(root='./data', name='pubmed')[0]

# Default: if `attr` is a node-level attribute, filter only nodes and their attributes.
data.node_time = torch.tensor([0 if x % 2 else 1 for x in range(data.x.shape[0])])
snapshot_node_time = data.snapshot(0, 0, attr='node_time')

# Default: if `attr` is an edge-level attribute, filter only edges and their attributes.
data.edge_time = torch.tensor([0 if x % 2 else 1 for x in range(data.edge_index.shape[1])])
snapshot_edge_time = data.snapshot(0, 0, attr='edge_time')

# Optional: filter both node- and edge-level data if `filter_all` is set to `True`.
snapshot_node_time_ = data.snapshot(0, 0, attr='node_time', filter_all=True)
snapshot_edge_time_ = data.snapshot(0, 0, attr='edge_time', filter_all=True)

print(
    '- Full dataset',
    data,
    "\n- Snapshot (attr='node_time', filter_all=False)",
    snapshot_node_time,
    "\n- Snapshot (attr='node_time', filter_all=True)",
    snapshot_node_time_,
    "\n- Snapshot (attr='edge_time', filter_all=False)",
    snapshot_edge_time,
    "\n- Snapshot (attr='edge_time', filter_all=True)",
    snapshot_edge_time_,
    sep='\n'
)

Output:

- Full dataset
Data(x=[19717, 500], edge_index=[2, 88648], y=[19717], train_mask=[19717], val_mask=[19717], test_mask=[19717], node_time=[19717], edge_time=[88648])

- Snapshot (attr='node_time', filter_all=False)
Data(x=[9858, 500], edge_index=[2, 88648], y=[9858], train_mask=[9858], val_mask=[9858], test_mask=[9858], node_time=[9858])

- Snapshot (attr='node_time', filter_all=True)
Data(x=[9858, 500], edge_index=[2, 66646], y=[9858], train_mask=[9858], val_mask=[9858], test_mask=[9858], node_time=[9858], edge_time=[66646])

- Snapshot (attr='edge_time', filter_all=False)
Data(x=[19717, 500], edge_index=[2, 44324], y=[19717], train_mask=[19717], val_mask=[19717], test_mask=[19717], node_time=[19717], edge_time=[44324])

- Snapshot (attr='edge_time', filter_all=True)
Data(x=[17477, 500], edge_index=[2, 44324], y=[17477], train_mask=[17477], val_mask=[17477], test_mask=[17477], node_time=[17477], edge_time=[44324])
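
As a quick sanity check on the sizes above, the first-pass counts can be reproduced without PyG. This is a small sketch, assuming that snapshot(0, 0, attr=...) keeps exactly the entries whose time value is 0, as the output suggests; count_selected is a hypothetical helper name.

```python
def count_selected(n):
    # Mirrors the example's rule: index x gets time 0 when x is odd, else 1.
    times = [0 if x % 2 else 1 for x in range(n)]
    # ASSUMPTION: the snapshot keeps the entries whose time equals 0.
    return sum(1 for t in times if t == 0)


print(count_selected(19717))  # nodes kept with attr='node_time': 9858
print(count_selected(88648))  # edges kept with attr='edge_time': 44324
```

The remaining sizes (66646 edges and 17477 nodes with filter_all=True) depend on the graph structure and cannot be derived from the index parity alone.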

@nelsonaloysio
Contributor Author

By the way: I'm not 100% sure about the argument name filter_all, but it seemed like the most straightforward choice so far.
