-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathGLSTM_training.py
34 lines (30 loc) · 1.17 KB
/
GLSTM_training.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
import torch
import torch.nn as nn
def train2(model, data, criterion, optimizer, mask=None):
    """Run one full-batch training step and report loss, accuracy and gradient norms.

    The step is: switch to training mode, clear gradients, forward pass,
    loss, backward pass, record gradient norms, then one optimizer update.

    Args:
        model: Module called as ``model(data.x, data.edge_index)``. Its output
            is argmax-ed along dim 1, so it is expected to hold one row of
            class scores per node.
        data: Object exposing ``x``, ``edge_index``, ``y`` and ``num_nodes``
            (presumably a ``torch_geometric.data.Data`` graph; confirm
            against callers). ``y`` is compared against predicted class
            indices.
        criterion: Loss callable invoked as ``criterion(out, target)``.
        optimizer: Optimizer whose ``zero_grad()`` and ``step()`` are each
            called once.
        mask: Optional boolean mask or index tensor selecting the nodes that
            contribute to the loss and accuracy (e.g. ``data.train_mask``).
            ``None`` (the default) uses every node, matching the original
            behaviour.

    Returns:
        Tuple ``(loss, accuracy, gradients, param_names)``:
        ``loss`` is the loss as a Python float; ``accuracy`` is the fraction
        of selected nodes predicted correctly, computed from the outputs of
        the forward pass *before* the optimizer update; ``gradients`` holds
        ``param.grad.norm()`` for every parameter that received a gradient;
        ``param_names`` holds the matching parameter names in the same order.
    """
    model.train()
    optimizer.zero_grad()
    out = model(data.x, data.edge_index)
    target = data.y
    if mask is not None:
        # Restrict loss/accuracy to the selected nodes, e.g. the train split,
        # so labels of held-out nodes are never used for the update.
        out, target = out[mask], target[mask]
    loss = criterion(out, target)
    loss.backward()

    # Snapshot gradient norms (for plotting) after backward() but before
    # step(); parameters without a gradient (grad is None) are skipped.
    named_grads = [
        (name, param.grad)
        for name, param in model.named_parameters()
        if param.grad is not None
    ]
    gradients = [grad.norm().item() for _, grad in named_grads]
    param_names = [name for name, _ in named_grads]

    optimizer.step()

    pred = out.argmax(dim=1)
    correct = pred.eq(target).sum().item()
    # Without a mask keep the original denominator (all nodes); with a mask
    # count the selected rows, which works for boolean and index masks alike.
    total = data.num_nodes if mask is None else target.size(0)
    accuracy = correct / total
    return loss.item(), accuracy, gradients, param_names
# Evaluation function
def evaluate(model, data, criterion, optimizer=None, mask=None):
    """Compute loss and accuracy of ``model`` on ``data`` without tracking gradients.

    Args:
        model: Module called as ``model(data.x, data.edge_index)``. Its output
            is argmax-ed along dim 1 to obtain class predictions.
        data: Object exposing ``x``, ``edge_index``, ``y`` and ``num_nodes``
            (presumably a ``torch_geometric.data.Data`` graph; confirm
            against callers).
        criterion: Loss callable invoked as ``criterion(out, target)``.
        optimizer: Unused. Kept, now optional, so existing call sites that
            pass it positionally keep working.
        mask: Optional boolean mask or index tensor selecting the nodes to
            evaluate (e.g. ``data.val_mask`` / ``data.test_mask``). ``None``
            (the default) evaluates every node, matching the original
            behaviour.

    Returns:
        Tuple ``(loss, accuracy)``: the loss as a Python float and the
        fraction of selected nodes whose argmax prediction equals ``y``.

    Note:
        The model is left in eval mode; callers must switch it back to
        training mode themselves.
    """
    model.eval()
    with torch.no_grad():
        out = model(data.x, data.edge_index)
        target = data.y
        if mask is not None:
            out, target = out[mask], target[mask]
        loss = criterion(out, target)
        correct = out.argmax(dim=1).eq(target).sum().item()
    # Without a mask keep the original denominator (all nodes); with a mask
    # count the selected rows, which works for boolean and index masks alike.
    total = data.num_nodes if mask is None else target.size(0)
    accuracy = correct / total
    return loss.item(), accuracy