From 4d090ec5d9c849534f846e857197b78ed774fef1 Mon Sep 17 00:00:00 2001 From: Sam Duffield Date: Wed, 14 Sep 2022 11:57:42 +0100 Subject: [PATCH 1/5] added more gates --- qujax/gates.py | 135 +++++++++++++++++++++++++++++++++++++++++------ qujax/version.py | 2 +- 2 files changed, 120 insertions(+), 17 deletions(-) diff --git a/qujax/gates.py b/qujax/gates.py index 5d3f358..733cbbb 100644 --- a/qujax/gates.py +++ b/qujax/gates.py @@ -2,6 +2,8 @@ I = jnp.eye(2) +_0 = jnp.zeros((2, 2)) + X = jnp.array([[0., 1.], [1., 0.]]) @@ -38,20 +40,45 @@ SXdg = jnp.array([[1. - 1.j, 1. + 1.j], [1. + 1.j, 1. - 1.j]]) / 2 -CX = jnp.array([[1., 0., 0., 0.], - [0., 1., 0., 0.], - [0., 0., 0., 1.], - [0., 0., 1., 0.]]).reshape((2,) * 4) +CX = jnp.block([[I, _0], + [_0, X]]).reshape((2,) * 4) + +CY = jnp.block([[I, _0], + [_0, Y]]).reshape((2,) * 4) + +CZ = jnp.block([[I, _0], + [_0, Z]]).reshape((2,) * 4) + +CH = jnp.block([[I, _0], + [_0, H]]).reshape((2,) * 4) + +CV = jnp.block([[I, _0], + [_0, V]]).reshape((2,) * 4) + +CVdg = jnp.block([[I, _0], + [_0, Vdg]]).reshape((2,) * 4) + +CSX = jnp.block([[I, _0], + [_0, SX]]).reshape((2,) * 4) + +CSXdg = jnp.block([[I, _0], + [_0, SXdg]]).reshape((2,) * 4) + +CCX = jnp.block([[I, _0, _0, _0], # Toffoli gate + [_0, I, _0, _0], + [_0, _0, I, _0], + [_0, _0, _0, X]]).reshape((2,) * 6) -CY = jnp.array([[1., 0., 0., 0.], - [0., 1., 0., 0.], - [0., 0., 0., -1.j], - [0., 0., 1.j, 0.]]).reshape((2,) * 4) +ECR = jnp.block([[_0, Vdg], + [V, _0]]).reshape((2,) * 4) -CZ = jnp.array([[1., 0., 0., 0.], - [0., 1., 0., 0.], - [0., 0., 1., 0.], - [0., 0., 0., -1.]]).reshape((2,) * 4) +SWAP = jnp.array([[1., 0., 0., 0.], + [0., 0., 1., 0.], + [0., 1., 0., 0.], + [0., 0., 0., 1]]) + +CSWAP = jnp.block([[jnp.eye(4), jnp.zeros((4, 4))], + [jnp.zeros((4, 4)), SWAP]]).reshape((2,) * 6) def Rx(param: float) -> jnp.ndarray: @@ -69,14 +96,90 @@ def Rz(param: float) -> jnp.ndarray: return jnp.cos(param_pi_2) * I - jnp.sin(param_pi_2) * Z * 1.j +def CRx(param: float) -> jnp.ndarray: + return jnp.block([[I, _0], + [_0, Rx(param)]]).reshape((2,) * 4) + + +def CRy(param: float) -> jnp.ndarray: + return jnp.block([[I, _0], + [_0, Ry(param)]]).reshape((2,) * 4) + + +def CRz(param: float) -> jnp.ndarray: + return jnp.block([[I, _0], + [_0, Rz(param)]]).reshape((2,) * 4) + + def U1(param: float) -> jnp.ndarray: return U3(0, 0, param) -def U2(param1: float, param2: float) -> jnp.ndarray: - return U3(0.5, param1, param2) +def U2(param0: float, param1: float) -> jnp.ndarray: + return U3(0.5, param0, param1) + + +def U3(param0: float, param1: float, param2: float) -> jnp.ndarray: + return jnp.exp((param1 + param2) * jnp.pi * 1.j / 2) * Rz(param1) @ Ry(param0) @ Rz(param2) -def U3(param1: float, param2: float, param3: float) -> jnp.ndarray: - return jnp.exp((param2 + param3) * jnp.pi * 1.j / 2) * Rz(param2) @ Ry(param1) @ Rz(param3) +def CU1(param: float) -> jnp.ndarray: + return jnp.block([[I, _0], + [_0, U1(param)]]).reshape((2,) * 4) + +def CU2(param0: float, param1: float) -> jnp.ndarray: + return jnp.block([[I, _0], + [_0, U2(param0, param1)]]).reshape((2,) * 4) + + +def CU3(param0: float, param1: float, param2: float) -> jnp.ndarray: + return jnp.block([[I, _0], + [_0, U3(param0, param1, param2)]]).reshape((2,) * 4) + + +def ISWAP(param: float) -> jnp.ndarray: + param_pi_2 = param * jnp.pi / 2 + c = jnp.cos(param_pi_2) + i_s = 1.j * jnp.sin(param_pi_2) + return jnp.array([[1., 0., 0., 0.], + [0., c, i_s, 0.], + [0., i_s, c, 0.], + [0., 0., 0., 1.]]).reshape((2,) * 4) + + +def PhasedISWAP(param0: float, param1: float) -> jnp.ndarray: + param1_pi_2 = param1 * jnp.pi / 2 + c = jnp.cos(param1_pi_2) + i_s_e = 1.j * jnp.sin(param1_pi_2) * jnp.exp(-2.j * jnp.pi * param0) + return jnp.array([[1., 0., 0., 0.], + [0., c, i_s_e, 0.], + [0., i_s_e, c, 0.], + [0., 0., 0., 1.]]).reshape((2,) * 4) + + +def XXPhase(param: float) -> jnp.ndarray: + param_pi_2 = param * jnp.pi / 2 + c = jnp.cos(param_pi_2) + i_s = 1.j * jnp.sin(param_pi_2) + return jnp.array([[c, 0., 0., -i_s], + [0., c, -i_s, 0.], + [0., -i_s, c, 0.], + [-i_s, 0., 0., c]]).reshape((2,) * 4) + + +def YYPhase(param: float) -> jnp.ndarray: + param_pi_2 = param * jnp.pi / 2 + c = jnp.cos(param_pi_2) + i_s = 1.j * jnp.sin(param_pi_2) + return jnp.array([[c, 0., 0., i_s], + [0., c, -i_s, 0.], + [0., -i_s, c, 0.], + [i_s, 0., 0., c]]).reshape((2,) * 4) + + +def ZZPhase(param: float) -> jnp.ndarray: + param_pi_2 = param * jnp.pi / 2 + e_m = jnp.exp(-1.j * param_pi_2) + e_p = jnp.exp(1.j * param_pi_2) + return jnp.diag(jnp.array([e_m, e_p, e_p, e_m])).reshape((2,) * 4) diff --git a/qujax/version.py b/qujax/version.py index 407b8a2..14e974f 100644 --- a/qujax/version.py +++ b/qujax/version.py @@ -1 +1 @@ -__version__ = '0.2.7' +__version__ = '0.2.8' From 7cee1777271241192f5543d8d8803730a5896462 Mon Sep 17 00:00:00 2001 From: Sam Duffield Date: Wed, 14 Sep 2022 13:49:32 +0100 Subject: [PATCH 2/5] sign error --- qujax/gates.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/qujax/gates.py b/qujax/gates.py index 733cbbb..6111086 100644 --- a/qujax/gates.py +++ b/qujax/gates.py @@ -151,10 +151,10 @@ def ISWAP(param: float) -> jnp.ndarray: def PhasedISWAP(param0: float, param1: float) -> jnp.ndarray: param1_pi_2 = param1 * jnp.pi / 2 c = jnp.cos(param1_pi_2) - i_s_e = 1.j * jnp.sin(param1_pi_2) * jnp.exp(-2.j * jnp.pi * param0) + i_s = 1.j * jnp.sin(param1_pi_2) return jnp.array([[1., 0., 0., 0.], - [0., c, i_s_e, 0.], - [0., i_s_e, c, 0.], + [0., c, i_s * jnp.exp(2.j * jnp.pi * param0), 0.], + [0., i_s * jnp.exp(-2.j * jnp.pi * param0), c, 0.], [0., 0., 0., 1.]]).reshape((2,) * 4) From 6e7fa188d71eb642c7f0d0d9f110c4c837aa9dc4 Mon Sep 17 00:00:00 2001 From: Sam Duffield Date: Wed, 14 Sep 2022 14:17:53 +0100 Subject: [PATCH 3/5] add unitary check --- qujax/circuit.py | 10 +++------- qujax/circuit_tools.py | 33 +++++++++++++++++++++++++++++++++ 2 files changed, 36 insertions(+), 7 deletions(-) diff --git a/qujax/circuit.py b/qujax/circuit.py index d0f0556..1c4bd49 100644 --- a/qujax/circuit.py +++ b/qujax/circuit.py @@ -95,11 +95,7 @@ def _array_to_callable(arr: jnp.ndarray) -> Callable[[], jnp.ndarray]: gate_seq_callable = [] for gate in gate_seq: if isinstance(gate, str): - if gate in gates.__dict__: - gate = gates.__dict__[gate] - else: - raise KeyError(f'Gate string \'{gate}\' not found in qujax.gates ' - f'- consider changing input to an array or callable') + gate = gates.__dict__[gate] if callable(gate): gate_func = gate @@ -109,8 +105,8 @@ def _array_to_callable(arr: jnp.ndarray) -> Callable[[], jnp.ndarray]: gate = gate_arr.reshape((2,) * int(jnp.log2(gate_size))) gate_func = _array_to_callable(gate) else: - raise TypeError('Unsupported gate type' - '- gate must be either a string in qujax.gates, an array or callable') + raise TypeError(f'Unsupported gate type - gate must be either a string in qujax.gates, an array or ' + f'callable: {gate}') gate_seq_callable.append(gate_func) apply_gate_seq = [_get_apply_gate(g, q) for g, q in zip(gate_seq_callable, qubit_inds_seq)] diff --git a/qujax/circuit_tools.py b/qujax/circuit_tools.py index 21b3481..ed1ff0a 100644 --- a/qujax/circuit_tools.py +++ b/qujax/circuit_tools.py @@ -1,9 +1,39 @@ from __future__ import annotations from typing import Sequence, Union, Callable, List, Tuple, Optional import collections.abc +from inspect import signature from jax import numpy as jnp +from qujax import gates + + +def check_unitary(gate: Union[str, + jnp.ndarray, + Callable[[jnp.ndarray], jnp.ndarray], + Callable[[], jnp.ndarray]]): + if isinstance(gate, str): + if gate in gates.__dict__: + gate = gates.__dict__[gate] + else: + raise KeyError(f'Gate string \'{gate}\' not found in qujax.gates ' + f'- consider changing input to an array or callable') + + if callable(gate): + num_args = len(signature(gate).parameters) + gate_arr = gate(*jnp.ones(num_args) * 0.1) + elif hasattr(gate, '__array__'): + gate_arr = gate + else: + raise TypeError(f'Unsupported gate type - gate must be either a string in qujax.gates, an array or ' + f'callable: {gate}') + + gate_square_dim = int(jnp.sqrt(gate_arr.size)) + gate_arr = gate_arr.reshape(gate_square_dim, gate_square_dim) + + if jnp.any(jnp.abs(gate_arr @ jnp.conjugate(gate_arr).T - jnp.eye(gate_square_dim)) > 1e-3): + raise TypeError(f'Gate not unitary: {gate}') + def check_circuit(gate_seq: Sequence[Union[str, jnp.ndarray, @@ -45,6 +75,9 @@ def check_circuit(gate_seq: Sequence[Union[str, if n_qubits is not None and n_qubits < max([max(qi) for qi in qubit_inds_seq]) + 1: raise TypeError('n_qubits must be larger than largest qubit index in qubit_inds_seq') + for g in gate_seq: + check_unitary(g) + def _get_gate_str(gate_obj: Union[str, jnp.ndarray, From f00ced0bb9e29d40e9944ce1a4a38c2313197a15 Mon Sep 17 00:00:00 2001 From: Sam Duffield Date: Wed, 14 Sep 2022 14:26:56 +0100 Subject: [PATCH 4/5] add test_gates --- tests/test_gates.py | 9 +++++++++ 1 file changed, 9 insertions(+) create mode 100644 tests/test_gates.py diff --git a/tests/test_gates.py b/tests/test_gates.py new file mode 100644 index 0000000..9f41f59 --- /dev/null +++ b/tests/test_gates.py @@ -0,0 +1,9 @@ +from qujax import gates +from qujax.circuit_tools import check_unitary + + +def test_gates(): + for g_str, g in gates.__dict__.items(): + if g_str[0] != '_' and g_str != 'jnp': + check_unitary(g_str) + check_unitary(g) From 870a0ad174c5f05fb08fbbde38dd37773096a3ec Mon Sep 17 00:00:00 2001 From: Sam Duffield Date: Wed, 14 Sep 2022 14:32:27 +0100 Subject: [PATCH 5/5] check_unitary added to __init__ --- qujax/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/qujax/__init__.py b/qujax/__init__.py index 7941cd7..141067b 100644 --- a/qujax/__init__.py +++ b/qujax/__init__.py @@ -12,6 +12,7 @@ from qujax.observable import sample_integers from qujax.observable import sample_bitstrings +from qujax.circuit_tools import check_unitary from qujax.circuit_tools import check_circuit from qujax.circuit_tools import print_circuit