Fix typing for models.py
mihaeladuta committed Nov 14, 2024
1 parent 39c0048 commit b39320a
Showing 2 changed files with 56 additions and 55 deletions.
100 changes: 50 additions & 50 deletions l2gv2/models.py
@@ -2,14 +2,14 @@
embeddings of a list of patches using VGAE and Node2Vec."""

from typing import Tuple
import numpy as np
from tqdm import tqdm
import torch
import torch.nn.functional as F
import torch_geometric as tg
from torch_geometric.utils.convert import from_networkx
from torch_geometric.nn import Node2Vec

import local2global as l2g
from l2gv2.patch.patch import Patch
from l2gv2.patch.utils import WeightedAlignmentProblem

@@ -20,12 +20,12 @@ def speye(n: int, dtype: torch.dtype = torch.float) -> torch.Tensor:
"""Returns the identity matrix of dimension n as torch.sparse_coo_tensor.
Args:
- n (int): dimension of the identity matrix.
+ n: dimension of the identity matrix.
- dtype (torch.dtype, optional): data type of the identity matrix, default is torch.float.
+ dtype: data type of the identity matrix, default is torch.float.
Returns:
- torch.Tensor: identity matrix of dimension n as torch.sparse_coo_tensor.
+ identity matrix of dimension n as torch.sparse_coo_tensor.
"""

@@ -80,7 +80,7 @@ def forward(self, data: tg.data.Data) -> Tuple[torch.Tensor, torch.Tensor]:
"""Forward pass.
Args:
- data(torch_geometric.data.Data): object with edge_index and x attributes.
+ data: object with edge_index and x attributes.
Returns:
mu: mean of the latent space.
@@ -110,22 +110,22 @@ def train(
"""Train a model on a dataset.
Args:
- data (torch_geometric.data.Data): object to train the model on.
+ data: object to train the model on.
- model (torch.nn.Module): the model to train.
+ model: the model to train.
loss_fun: function that takes the model and the data and returns a scalar loss.
- num_epochs (int, optional): number of epochs to train the model, default is 100.
+ num_epochs: number of epochs to train the model, default is 100.
- verbose(bool, optional): if True, print the loss at each epoch, default is True.
+ verbose: if True, print the loss at each epoch, default is True.
- lr (float, optional): learning rate, default is 0.01.
+ lr: learning rate, default is 0.01.
- logger (optional): function that takes the loss and logs it.
+ logger: function that takes the loss and logs it.
Returns:
- torch.nn.Module: trained model.
+ trained model
"""

optimizer = torch.optim.Adam(model.parameters(), lr=lr)
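
[Editor's note] A hedged sketch of how `train` might be called, with argument names taken from the docstring above. The loss function shown assumes a torch_geometric VGAE-style model (with `encode`, `recon_loss` and `kl_loss`); it is illustrative only and not code from this repository.

```python
# Illustrative only: `data` and `model` are assumed to exist already.
def recon_loss(model, data):
    # Assumes a torch_geometric.nn.VGAE model whose encoder takes (x, edge_index).
    z = model.encode(data.x, data.edge_index)
    return model.recon_loss(z, data.edge_index) + model.kl_loss()

trained_model = train(
    data,                 # torch_geometric.data.Data object
    model,                # the model to train
    loss_fun=recon_loss,
    num_epochs=100,
    verbose=True,
    lr=0.01,
)
```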
@@ -154,22 +154,22 @@ def vgae_patch_embeddings(
"""TODO: docstring for `vgae_patch_embeddings`
Args:
- patch_data (list): list of torch_geometric.data.Data objects.
+ patch_data: list of torch_geometric.data.Data objects.
- dim (int, optional): dimension of the latent space, default is 100.
+ dim: dimension of the latent space, default is 100.
- hidden_dim (int, optional): hidden dimension of the encoder, default is 32.
+ hidden_dim: hidden dimension of the encoder, default is 32.
- num_epochs (int, optional): number of epochs to train the model, default is 100.
+ num_epochs: number of epochs to train the model, default is 100.
- decoder (optional): decoder function, default is None.
+ decoder : decoder function, default is None.
- lr (float, optional): learning rate, default is 0.01.
+ lr : learning rate, default is 0.01.
Returns:
- list: list of Patch objects.
+ list of Patch objects.
- torch.nn.Module: list of trained models.
+ list of trained models.
"""

patch_list = []
@@ -197,7 +197,7 @@ def loss_fun(model, data):
model.eval()
coordinates = model.encode(patch).to("cpu").numpy()
models.append(model)
- patch_list.append(l2g.Patch(patch.nodes.to("cpu").numpy(), coordinates))
+ patch_list.append(Patch(patch.nodes.to("cpu").numpy(), coordinates))
return patch_list, models
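
[Editor's note] Based on the docstring above, a sketch of calling `vgae_patch_embeddings` on a list of patch graphs. Parameter names are taken from the docstring; `patch_data` is an assumed list of `torch_geometric.data.Data` objects.

```python
# Hedged sketch: embed each patch graph with a VGAE and collect the results.
patches, models = vgae_patch_embeddings(
    patch_data,      # list of torch_geometric.data.Data objects
    dim=100,         # latent-space dimension
    hidden_dim=32,   # encoder hidden dimension
    num_epochs=100,
    lr=0.01,
)
# `patches` is a list of Patch objects (node indices + coordinates),
# `models` the corresponding trained VGAE models.
```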


@@ -216,31 +216,31 @@ def node2vec_(
Args:
- data(torch_geometric.data.Data)
+ data: [description]
- emb_dim (int): [description]
+ emb_dim: [description]
- w_length (int, optional): The walk length, default is 20.
+ w_length: The walk length, default is 20.
- c_size (int, optional): Actual context size which is considered for
+ c_size: Actual context size which is considered for
positive samples, default is 10.
- w_per_node (int, optional): The number of walks per node, default is 10.
+ w_per_node: The number of walks per node, default is 10.
- n_negative_samples (int, optional): The number of negative samples to
+ n_negative_samples: The number of negative samples to
use for each positive sample., default is 1.
- p (int, optional): likelihood of immediately revisiting a node in the
+ p: likelihood of immediately revisiting a node in the
walk, default is 1.
- q (int, optional): Control parameter to interpolate between
+ q: Control parameter to interpolate between
breadth-first strategy and depth-first strategy, default is 1.
- num_epoch (int, optional): number of epochs to train the model, default is 100.
+ num_epoch: number of epochs to train the model, default is 100.
Returns:
Patch:
[description]
"""

node2vec_model = Node2Vec(
@@ -278,7 +278,7 @@ def node2vec_(
node_embeddings = node2vec_model.embedding.weight.data.cpu()
# models.append(model)

- return l2g.Patch(data.nodes.to("cpu").numpy(), node_embeddings)
+ return Patch(data.nodes.to("cpu").numpy(), node_embeddings)


def node2vec_patch_embeddings(
@@ -295,29 +295,29 @@ def node2vec_patch_embeddings(
Args:
- patch_data (list) torch_geometric.data.Data objects.
+ patch_data: torch_geometric.data.Data objects.
- emb_dim (int): [description]
+ emb_dim: [description]
- w_length (int, optional): The walk length, default is 20.
+ w_length: The walk length, default is 20.
- c_size (int, optional): Actual context size which is considered for
+ c_size: Actual context size which is considered for
positive samples, default is 10.
- w_per_node (int, optional): The number of walks per node, default is 10.
+ w_per_node: The number of walks per node, default is 10.
- n_negative_samples (int, optional): The number of negative samples to
+ n_negative_samples: The number of negative samples to
use for each positive sample., default is 1.
- p (int, optional): likelihood of immediately revisiting a node in the
+ p: likelihood of immediately revisiting a node in the
walk, default is 1.
- q (int, optional): Control parameter to interpolate between
+ q: Control parameter to interpolate between
breadth-first strategy and depth-first strategy, default is 1.
Returns:
- patch_list: list of Patch objects.
+ list of Patch objects.
"""

patch_list = []
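
[Editor's note] Similarly, a hedged usage sketch of `node2vec_patch_embeddings`, with keyword names taken from the docstring above; `patch_data` and the embedding dimension are assumed example inputs.

```python
# Illustrative call; keyword defaults mirror the documented values.
n2v_patches = node2vec_patch_embeddings(
    patch_data,           # list of torch_geometric.data.Data objects
    emb_dim=64,           # embedding dimension (example value)
    w_length=20,          # walk length
    c_size=10,            # context size for positive samples
    w_per_node=10,        # walks per node
    n_negative_samples=1,
    p=1,
    q=1,
)
# Returns a list of Patch objects, one per input patch graph.
```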
@@ -343,24 +343,24 @@ def node2vec_patch_embeddings(

def chunk_embedding(
chunk_size: int, patches: list[Patch], dim=2
- ) -> Tuple[torch.Tensor, WeightedAlignmentProblem]:
+ ) -> Tuple[np.ArrayLike, WeightedAlignmentProblem]:
"""TODO: docstring for `chunk_embedding`
Note: this only works for Autonomous System dataset.
Args:
- chunk_size (int): The size of the chunks.
+ chunk_size: The size of the chunks.
- patches (list[Patch]): list of Patch objects.
+ patches: list of Patch objects.
- dim (int, optional): The dimension of the embeddings, default is 2.
+ dim: The dimension of the embeddings, default is 2.
Returns:
- nodes_emb: embeddings of the nodes.
+ embeddings of the nodes.
- prob: WeightedAlignmentProblem object.
+ WeightedAlignmentProblem object.
"""

@@ -404,7 +404,7 @@ def chunk_embedding(
lr=0.01,
)

- prob = l2g.utils.WeightedAlignmentProblem(
+ prob = WeightedAlignmentProblem(
p_emb[0]
) # embedding of the full graph using embeddings of each patch

@@ -414,7 +414,7 @@ def chunk_embedding(
for i in range(len(emb))
]

- prob = l2g.utils.WeightedAlignmentProblem(ppatch2)
+ prob = WeightedAlignmentProblem(ppatch2)

emb = prob.get_aligned_embedding()
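
[Editor's note] Finally for models.py, a sketch of how `chunk_embedding` ties the pieces together (per the note in its docstring it is specific to the Autonomous System dataset). `as_patches` and the chunk size below are assumed example inputs.

```python
# Hedged sketch: align per-chunk embeddings into one global embedding.
nodes_emb, prob = chunk_embedding(chunk_size=1000, patches=as_patches, dim=2)

# nodes_emb: aligned node embeddings (array-like, per the new return annotation)
# prob:      the WeightedAlignmentProblem used to align the patch embeddings
```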

11 changes: 6 additions & 5 deletions l2gv2/patch/patch.py
@@ -34,12 +34,13 @@ class Patch:
coordinates = None
"""patch embedding coordinates"""

- def __init__(self, nodes: iter, coordinates: str=None):
+ def __init__(self, nodes: iter, coordinates: np.ArrayLike | None = None):
""" Initialise a patch from a list of nodes and corresponding coordinates
Args:
- nodes (iter): Iterable of integer node indeces for patch
- coordinates (str, optional): filename for coordinate file to be loaded on demand
+ nodes: Iterable of integer node indeces for patch
+ coordinates: [description]. Defaults to None.
"""
self.nodes = np.asanyarray(nodes)
self.index = {int(n): i for i, n in enumerate(nodes)}
@@ -62,15 +63,15 @@ def get_coordinates(self, nodes: iter):
""" Get coordinates for a list of nodes
Args:
- nodes (iter): Iterable of node indeces
+ nodes: Iterable of node indeces
"""
return self.coordinates[[self.index[node] for node in nodes], :]

def get_coordinate(self, node: int):
""" Get coordinate for a single node
Args:
- node (int): The node index
+ node: The node index
"""
return self.coordinates[self.index[node], :]
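
[Editor's note] To illustrate the Patch API touched in this file, a small sketch consistent with the constructor and accessors shown above. The coordinate values are made up, and it assumes the constructor stores the passed coordinate array directly (the assignment is not shown in this hunk).

```python
import numpy as np

# Three nodes with 2-d coordinates; node ids need not be contiguous.
coords = np.array([[0.0, 1.0],
                   [2.0, 3.0],
                   [4.0, 5.0]])
patch = Patch([3, 5, 8], coords)

patch.get_coordinate(5)        # -> array([2., 3.])
patch.get_coordinates([3, 8])  # -> array([[0., 1.], [4., 5.]])
```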
