Skip to content

Commit

Permalink
[tests] Add tests for gmean type 2 and 3
Browse files Browse the repository at this point in the history
  • Loading branch information
tky823 committed Jan 9, 2023
1 parent 148dbd3 commit 8ca8909
Showing 1 changed file with 19 additions and 3 deletions.
22 changes: 19 additions & 3 deletions tests/ssspy/linalg/test_gmean.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import numpy as np
import pytest
from scipy.linalg import sqrtm

from ssspy.linalg import gmeanmh

parameters_type = [1, 2, 3]


def gmeanmh_scipy(A: np.ndarray, B: np.ndarray, inverse="left") -> np.ndarray:
def _sqrtm(X) -> np.ndarray:
Expand All @@ -21,7 +24,8 @@ def _sqrtm(X) -> np.ndarray:
return G


def test_gmean():
@pytest.mark.parametrize("type", parameters_type)
def test_gmean(type: int):
rng = np.random.default_rng(0)
size = (16, 32, 4, 1)

Expand All @@ -34,9 +38,21 @@ def create_psd():
A = create_psd()
B = create_psd()

G1 = gmeanmh(A, B)
G1 = gmeanmh(A, B, type=type)

if type == 1:
assert np.allclose(G1 @ np.linalg.inv(A) @ G1, B)
elif type == 2:
assert np.allclose(G1 @ A @ G1, B)
elif type == 3:
assert np.allclose(G1 @ np.linalg.inv(A) @ G1, np.linalg.inv(B))
else:
raise ValueError("Invalid type={} is given.".format(type))

assert np.allclose(G1 @ np.linalg.inv(A) @ G1, B)
if type == 2:
A = np.linalg.inv(A)
elif type == 3:
B = np.linalg.inv(B)

G2 = gmeanmh_scipy(A, B, inverse="left")
G3 = gmeanmh_scipy(A, B, inverse="right")
Expand Down

0 comments on commit 8ca8909

Please sign in to comment.