classic_training.py (forked from BenediktRiegel/quantum-no-free-lunch)
import time

import numpy as np
import torch

from cost_modifying_functions import identity

# torch.manual_seed(4241)
# np.random.seed(4241)
def cost_func(X, y_conj, qnn, device='cpu'):
    """
    Compute the cost function based on the circuit in Fig. 5 of Sharma et al.:
    C = 1 - (1/N) * sum_j |<y_j|V|x_j>|^2, where y_j = U x_j and V is the
    network's current unitary.
    """
    V = qnn.get_tensor_V()
    # overlap <y_j|V|x_j> for every training state in the batch
    dot_products = torch.sum(torch.mul(torch.matmul(V, X), y_conj), dim=[1, 2])
    # mean squared magnitude of the overlaps
    cost = (torch.sum(torch.square(dot_products.real)) + torch.sum(torch.square(dot_products.imag))) / X.shape[0]
    return 1 - cost
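# Illustrative only (not part of the original module): a small sanity check of
# cost_func, under the assumption that X holds the training states column-wise
# with shape (num_samples, dim, schmidt_rank) and that qnn.get_tensor_V()
# returns a complex (dim, dim) tensor. With V equal to the target unitary U,
# every overlap has modulus 1, so the cost is 0. The _FixedQNN stub is
# hypothetical and exists only for this check.
class _FixedQNN:
    def __init__(self, V):
        self._V = V

    def get_tensor_V(self):
        return self._V


def _cost_func_sanity_check(dim=4, num_samples=3):
    U = torch.linalg.qr(torch.randn(dim, dim, dtype=torch.complex128))[0]
    X = torch.randn(num_samples, dim, 1, dtype=torch.complex128)
    X = X / torch.linalg.norm(X, dim=1, keepdim=True)
    y_conj = torch.matmul(U, X).conj()
    assert abs(cost_func(X, y_conj, _FixedQNN(U)).item()) < 1e-10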
def train(X, unitary, qnn, num_epochs, optimizer, scheduler=None, device='cpu', cost_modification=None):
    if cost_modification is None:
        cost_modification = identity
    losses = []
    # target states y_j = U x_j, conjugated once up front since U and X are fixed
    y_conj = torch.matmul(unitary, X).conj()
    i = 0
    for i in range(num_epochs):
        loss = cost_modification(cost_func(X, y_conj, qnn, device=device))
        losses.append(loss.item())
        if i % 100 == 0:
            print(f"\tepoch [{i+1}/{num_epochs}] loss={loss.item()}")
        if loss.item() == 0.0:
            # print(f"epoch [{i+1}/{num_epochs}] loss={loss.item()}\nstopped")
            break
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if scheduler is not None:
            scheduler.step(loss.item())
            # print(f"\tepoch [{i + 1}/{num_epochs}] lr={scheduler.get_lr()}")
    print(f"\tepoch [{i+1}/{num_epochs}] final loss {losses[-1]}")
    return losses
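

# Illustrative usage sketch (not part of the original module): the hypothetical
# _ToyQNN below stands in for the project's QNN classes and only assumes the
# get_tensor_V() interface used above. It parametrises V = exp(A - A^dagger),
# which keeps V unitary while its parameters are optimised with a plain Adam
# optimizer.
if __name__ == '__main__':
    class _ToyQNN(torch.nn.Module):
        def __init__(self, dim):
            super().__init__()
            # real and imaginary parts of an unconstrained complex matrix A
            self.a_re = torch.nn.Parameter(torch.randn(dim, dim))
            self.a_im = torch.nn.Parameter(torch.randn(dim, dim))

        def get_tensor_V(self):
            a = torch.complex(self.a_re, self.a_im)
            # the exponential of a skew-Hermitian matrix is unitary
            return torch.matrix_exp(a - a.conj().T)

    dim, num_samples = 4, 3
    qnn = _ToyQNN(dim)
    # random target unitary and normalised training states of shape (num_samples, dim, 1)
    target_U = torch.linalg.qr(torch.randn(dim, dim, dtype=torch.complex64))[0]
    X = torch.randn(num_samples, dim, 1, dtype=torch.complex64)
    X = X / torch.linalg.norm(X, dim=1, keepdim=True)

    optimizer = torch.optim.Adam(qnn.parameters(), lr=0.1)
    losses = train(X, target_U, qnn, num_epochs=500, optimizer=optimizer)
    print(f"final loss: {losses[-1]}")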