Skip to content

Commit

Permalink
Merge pull request #197 from BQSKit/canonical_unitary
Browse files Browse the repository at this point in the history
Support for Canonical Unitaries
  • Loading branch information
mtweiden authored Nov 1, 2023
2 parents 854e4c4 + 3dcba54 commit 69ff181
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 0 deletions.
34 changes: 34 additions & 0 deletions bqskit/utils/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,3 +223,37 @@ def compute_su_generators(n: int) -> npt.NDArray[np.complex128]:
t3 *= np.sqrt(2 / (n * (n - 1)))
generators.append(t3)
return np.array(generators, dtype=np.complex128)


def canonical_unitary(
unitary: npt.NDArray[np.complex128],
) -> npt.NDArray[np.complex128]:
"""
Computes a canonical form for the provided unitary.
If unitary matrices V, W differ only by a global phase, then
canonical_unitary(V) == canonical_unitary(W).
Args:
unitary (npt.NDArray[np.complex128]): A unitary matrix.
Returns:
npt.NDArray[np.complex128]: A unitary matrix.
References:
https://arxiv.org/abs/2306.05622
"""
determinant = np.linalg.det(unitary)
dimension = len(unitary)
# Compute special unitary
global_phase = np.angle(determinant) / dimension
global_phase = global_phase % (2 * np.pi / dimension)
global_phase_factor = np.exp(-1j * global_phase)
special_unitary = global_phase_factor * unitary
# Standardize speical unitary to account for exp(-i2pi/N) differences
first_row_mags = np.linalg.norm(special_unitary[0, :], ord=2)
index = np.argmax(first_row_mags)
std_phase = np.angle(special_unitary[0, index])
correction_phase = 0 - std_phase
std_correction = np.exp(1j * correction_phase)
return std_correction * special_unitary
21 changes: 21 additions & 0 deletions tests/utils/test_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from scipy.stats import unitary_group

from bqskit.qis.pauli import PauliMatrices
from bqskit.utils.math import canonical_unitary
from bqskit.utils.math import dexpmv
from bqskit.utils.math import dot_product
from bqskit.utils.math import pauli_expansion
Expand Down Expand Up @@ -185,3 +186,23 @@ def test_valid(self, reH: npt.NDArray[np.complex128]) -> None:
print(alpha)
H = PauliMatrices(int(np.log2(reH.shape[0]))).dot_product(alpha)
assert np.linalg.norm(H - reH) < 1e-16


class TestCanonicalUnitary:
@pytest.mark.parametrize(
'phase, num_qudits',
[
(np.exp(1j * 2 * np.pi * np.random.randn()), qudits)
for qudits in range(1, 6) for _ in range(100)
],
)
def test_canonical_unitary(
self,
phase: np.complex128,
num_qudits: int,
) -> None:
base_unitary = unitary_group.rvs(2**num_qudits)
canon_unitary = canonical_unitary(base_unitary)
phased_unitary = phase * base_unitary
recanon_unitary = canonical_unitary(phased_unitary)
assert np.allclose(canon_unitary, recanon_unitary, atol=1e-5)

0 comments on commit 69ff181

Please sign in to comment.