Skip to content

Commit

Permalink
move internal function of PCA out
Browse files Browse the repository at this point in the history
  • Loading branch information
hanbin973 committed Oct 18, 2024
1 parent 1a8ff7a commit af16340
Showing 1 changed file with 137 additions and 102 deletions.
239 changes: 137 additions & 102 deletions python/tskit/trees.py
Original file line number Diff line number Diff line change
Expand Up @@ -8670,106 +8670,6 @@ def pca(
)

random_state = np.random.default_rng(random_seed)

def _rand_pow_range_finder(
operator,
operator_dim: int,
rank: int,
depth: int,
num_vectors: int,
rng: np.random.Generator,
range_sketch: np.ndarray = None,
) -> np.ndarray:
"""
Algorithm 9 in https://arxiv.org/pdf/2002.01387
"""
assert num_vectors >= rank > 0, "num_vectors should be larger than rank"
if range_sketch is None:
test_vectors = rng.normal(size=(operator_dim, num_vectors))
Q = test_vectors
else:
Q = range_sketch
for _ in range(depth):
Q = np.linalg.qr(Q).Q
Q = operator(Q)
Q = np.linalg.qr(Q).Q
return Q[:, :rank]

def _rand_svd(
operator,
operator_dim: int,
rank: int,
depth: int,
num_vectors: int,
rng: np.random.Generator,
range_sketch: np.ndarray = None,
) -> (np.ndarray, np.ndarray, np.ndarray, float):
"""
Algorithm 8 in https://arxiv.org/pdf/2002.01387
"""
assert num_vectors >= rank > 0
Q = _rand_pow_range_finder(
operator, operator_dim, num_vectors, depth, num_vectors, rng, range_sketch
)
C = operator(Q).T
U_hat, D, V = np.linalg.svd(C, full_matrices=False)
U = Q @ U_hat

error_factor = np.power(
1 + 4 * np.sqrt(2 * operator_dim / (rank - 1)),
1 / (2 * depth + 1)
)
error_bound = D[-1] * (1 + error_factor)
return U[:, :rank], D[:rank], V[:rank], Q, error_bound

def _genetic_relatedness_vector_individual(
arr: np.ndarray,
centre: bool = True,
windows=None,
) -> np.ndarray:
ij = np.vstack(
[
[n, k]
for k, i in enumerate(individuals)
for n in self.individual(i).nodes
]
)
samples, sample_individuals = (
ij[:, 0],
ij[:, 1],
) # sample node index, individual of those nodes
x = (
arr - arr.mean(axis=0) if centre else arr
) # centering within index in rows
x = self.genetic_relatedness_vector(
W=x[sample_individuals],
windows=windows,
mode=mode,
centre=False,
nodes=samples,
)[0]

def bincount_fn(w):
return np.bincount(sample_individuals, w)

x = np.apply_along_axis(bincount_fn, axis=0, arr=x)
x = x - x.mean(axis=0) if centre else x # centering within index in cols

return x

def _genetic_relatedness_vector_node(
arr: np.ndarray,
centre: bool = True,
windows=None,
) -> np.ndarray:
x = arr - arr.mean(axis=0) if centre else arr
x = self.genetic_relatedness_vector(
W=x, windows=windows, mode=mode, centre=False, nodes=samples
)[0]
x = x - x.mean(axis=0) if centre else x

return x

drop_windows = windows is None
windows = self.parse_windows(windows)
num_windows = len(windows) - 1
Expand All @@ -8787,9 +8687,13 @@ def _genetic_relatedness_vector_node(
if output_type == "node"
else _genetic_relatedness_vector_individual
)

indices = (
samples
if output_type == "node"
else individuals
)
def _G(x):
return _f(x, centre=centre, windows=this_window) # NOQA: B023
return _f(tree_sequence=self, arr=x, indices=indices, mode=mode, centre=centre, windows=this_window) # NOQA: B023

U[i], D[i], _, Q[i], E[i] = _rand_svd(
operator=_G,
Expand Down Expand Up @@ -10387,3 +10291,134 @@ def write_ms(
)
else:
print(file=output)

def _rand_pow_range_finder(
operator,
operator_dim: int,
rank: int,
depth: int,
num_vectors: int,
rng: np.random.Generator,
range_sketch: np.ndarray = None,
) -> np.ndarray:
"""
Algorithm 9 in https://arxiv.org/pdf/2002.01387
"""
assert num_vectors >= rank > 0, "num_vectors should be larger than rank"
if range_sketch is None:
test_vectors = rng.normal(size=(operator_dim, num_vectors))
Q = test_vectors
else:
Q = range_sketch
for _ in range(depth):
Q = np.linalg.qr(Q).Q
Q = operator(Q)
Q = np.linalg.qr(Q).Q
return Q[:, :rank]

def _rand_svd(
operator,
operator_dim: int,
rank: int,
depth: int,
num_vectors: int,
rng: np.random.Generator,
range_sketch: np.ndarray = None,
) -> (np.ndarray, np.ndarray, np.ndarray, float):
"""
Algorithm 8 in https://arxiv.org/pdf/2002.01387
"""
assert num_vectors >= rank > 0
Q = _rand_pow_range_finder(
operator, operator_dim, num_vectors, depth, num_vectors, rng, range_sketch
)
C = operator(Q).T
U_hat, D, V = np.linalg.svd(C, full_matrices=False)
U = Q @ U_hat

error_factor = np.power(
1 + 4 * np.sqrt(2 * operator_dim / (rank - 1)),
1 / (2 * depth + 1)
)
error_bound = D[-1] * (1 + error_factor)
return U[:, :rank], D[:rank], V[:rank], Q, error_bound

def _genetic_relatedness_vector_individual(
tree_sequence: tskit.TreeSequence,
arr: np.ndarray,
indices: np.ndarray,
mode: str,
centre: bool = True,
windows = None,
) -> np.ndarray:
ij = np.vstack(
[
[n, k]
for k, i in enumerate(indices)
for n in tree_sequence.individual(i).nodes
]
)
samples, sample_individuals = (
ij[:, 0],
ij[:, 1],
) # sample node index, individual of those nodes
x = (
arr - arr.mean(axis=0) if centre else arr
) # centering within index in rows
x = tree_sequence.genetic_relatedness_vector(
W=x[sample_individuals],
windows=windows,
mode=mode,
centre=False,
nodes=samples,
)[0]

def bincount_fn(w):
return np.bincount(sample_individuals, w)

x = np.apply_along_axis(bincount_fn, axis=0, arr=x)
x = x - x.mean(axis=0) if centre else x # centering within index in cols

return x

def _genetic_relatedness_vector_node(
tree_sequence: tskit.TreeSequence,
arr: np.ndarray,
indices: np.ndarray,
mode: str,
centre: bool = True,
windows = None,
) -> np.ndarray:
x = arr - arr.mean(axis=0) if centre else arr
x = tree_sequence.genetic_relatedness_vector(
W=x, windows=windows, mode=mode, centre=False, nodes=indices,
)[0]
x = x - x.mean(axis=0) if centre else x

return x

@dataclass
class PCAResult:
"""
The result of a call to TreeSequence.pca() capturing the output values
and algorithm convergence details.
"""
loadings: np.ndarray
"""
The principal component loadings. It is an orthogonal matrix.
"""
eigen_values: np.ndarray
"""
Eigenvalues of the genetic relatedness matrix.
"""
range_sketch: np.ndarray
"""
Range sketch matrix. Can be used as an input for .pca() call with range_sketch option
to further improve precision..
"""
error_bound: np.ndarray
"""
Error bounds for the eigenvalues.
"""

0 comments on commit af16340

Please sign in to comment.