From af163400e8248db3eea2dd1c5feb540b4b0ce99f Mon Sep 17 00:00:00 2001 From: Hanbin Lee Date: Fri, 18 Oct 2024 13:40:03 -0400 Subject: [PATCH] move internal function of PCA out --- python/tskit/trees.py | 239 ++++++++++++++++++++++++------------------ 1 file changed, 137 insertions(+), 102 deletions(-) diff --git a/python/tskit/trees.py b/python/tskit/trees.py index fe5f4d9ab5..c5435900ba 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -8670,106 +8670,6 @@ def pca( ) random_state = np.random.default_rng(random_seed) - - def _rand_pow_range_finder( - operator, - operator_dim: int, - rank: int, - depth: int, - num_vectors: int, - rng: np.random.Generator, - range_sketch: np.ndarray = None, - ) -> np.ndarray: - """ - Algorithm 9 in https://arxiv.org/pdf/2002.01387 - """ - assert num_vectors >= rank > 0, "num_vectors should be larger than rank" - if range_sketch is None: - test_vectors = rng.normal(size=(operator_dim, num_vectors)) - Q = test_vectors - else: - Q = range_sketch - for _ in range(depth): - Q = np.linalg.qr(Q).Q - Q = operator(Q) - Q = np.linalg.qr(Q).Q - return Q[:, :rank] - - def _rand_svd( - operator, - operator_dim: int, - rank: int, - depth: int, - num_vectors: int, - rng: np.random.Generator, - range_sketch: np.ndarray = None, - ) -> (np.ndarray, np.ndarray, np.ndarray, float): - """ - Algorithm 8 in https://arxiv.org/pdf/2002.01387 - """ - assert num_vectors >= rank > 0 - Q = _rand_pow_range_finder( - operator, operator_dim, num_vectors, depth, num_vectors, rng, range_sketch - ) - C = operator(Q).T - U_hat, D, V = np.linalg.svd(C, full_matrices=False) - U = Q @ U_hat - - error_factor = np.power( - 1 + 4 * np.sqrt(2 * operator_dim / (rank - 1)), - 1 / (2 * depth + 1) - ) - error_bound = D[-1] * (1 + error_factor) - return U[:, :rank], D[:rank], V[:rank], Q, error_bound - - def _genetic_relatedness_vector_individual( - arr: np.ndarray, - centre: bool = True, - windows=None, - ) -> np.ndarray: - ij = np.vstack( - [ - [n, k] - for k, i in enumerate(individuals) - for n in self.individual(i).nodes - ] - ) - samples, sample_individuals = ( - ij[:, 0], - ij[:, 1], - ) # sample node index, individual of those nodes - x = ( - arr - arr.mean(axis=0) if centre else arr - ) # centering within index in rows - x = self.genetic_relatedness_vector( - W=x[sample_individuals], - windows=windows, - mode=mode, - centre=False, - nodes=samples, - )[0] - - def bincount_fn(w): - return np.bincount(sample_individuals, w) - - x = np.apply_along_axis(bincount_fn, axis=0, arr=x) - x = x - x.mean(axis=0) if centre else x # centering within index in cols - - return x - - def _genetic_relatedness_vector_node( - arr: np.ndarray, - centre: bool = True, - windows=None, - ) -> np.ndarray: - x = arr - arr.mean(axis=0) if centre else arr - x = self.genetic_relatedness_vector( - W=x, windows=windows, mode=mode, centre=False, nodes=samples - )[0] - x = x - x.mean(axis=0) if centre else x - - return x - drop_windows = windows is None windows = self.parse_windows(windows) num_windows = len(windows) - 1 @@ -8787,9 +8687,13 @@ def _genetic_relatedness_vector_node( if output_type == "node" else _genetic_relatedness_vector_individual ) - + indices = ( + samples + if output_type == "node" + else individuals + ) def _G(x): - return _f(x, centre=centre, windows=this_window) # NOQA: B023 + return _f(tree_sequence=self, arr=x, indices=indices, mode=mode, centre=centre, windows=this_window) # NOQA: B023 U[i], D[i], _, Q[i], E[i] = _rand_svd( operator=_G, @@ -10387,3 +10291,134 @@ def write_ms( ) else: print(file=output) + +def _rand_pow_range_finder( + operator, + operator_dim: int, + rank: int, + depth: int, + num_vectors: int, + rng: np.random.Generator, + range_sketch: np.ndarray = None, + ) -> np.ndarray: + """ + Algorithm 9 in https://arxiv.org/pdf/2002.01387 + """ + assert num_vectors >= rank > 0, "num_vectors should be larger than rank" + if range_sketch is None: + test_vectors = rng.normal(size=(operator_dim, num_vectors)) + Q = test_vectors + else: + Q = range_sketch + for _ in range(depth): + Q = np.linalg.qr(Q).Q + Q = operator(Q) + Q = np.linalg.qr(Q).Q + return Q[:, :rank] + +def _rand_svd( + operator, + operator_dim: int, + rank: int, + depth: int, + num_vectors: int, + rng: np.random.Generator, + range_sketch: np.ndarray = None, + ) -> (np.ndarray, np.ndarray, np.ndarray, float): + """ + Algorithm 8 in https://arxiv.org/pdf/2002.01387 + """ + assert num_vectors >= rank > 0 + Q = _rand_pow_range_finder( + operator, operator_dim, num_vectors, depth, num_vectors, rng, range_sketch + ) + C = operator(Q).T + U_hat, D, V = np.linalg.svd(C, full_matrices=False) + U = Q @ U_hat + + error_factor = np.power( + 1 + 4 * np.sqrt(2 * operator_dim / (rank - 1)), + 1 / (2 * depth + 1) + ) + error_bound = D[-1] * (1 + error_factor) + return U[:, :rank], D[:rank], V[:rank], Q, error_bound + +def _genetic_relatedness_vector_individual( + tree_sequence: tskit.TreeSequence, + arr: np.ndarray, + indices: np.ndarray, + mode: str, + centre: bool = True, + windows = None, + ) -> np.ndarray: + ij = np.vstack( + [ + [n, k] + for k, i in enumerate(indices) + for n in tree_sequence.individual(i).nodes + ] + ) + samples, sample_individuals = ( + ij[:, 0], + ij[:, 1], + ) # sample node index, individual of those nodes + x = ( + arr - arr.mean(axis=0) if centre else arr + ) # centering within index in rows + x = tree_sequence.genetic_relatedness_vector( + W=x[sample_individuals], + windows=windows, + mode=mode, + centre=False, + nodes=samples, + )[0] + + def bincount_fn(w): + return np.bincount(sample_individuals, w) + + x = np.apply_along_axis(bincount_fn, axis=0, arr=x) + x = x - x.mean(axis=0) if centre else x # centering within index in cols + + return x + +def _genetic_relatedness_vector_node( + tree_sequence: tskit.TreeSequence, + arr: np.ndarray, + indices: np.ndarray, + mode: str, + centre: bool = True, + windows = None, + ) -> np.ndarray: + x = arr - arr.mean(axis=0) if centre else arr + x = tree_sequence.genetic_relatedness_vector( + W=x, windows=windows, mode=mode, centre=False, nodes=indices, + )[0] + x = x - x.mean(axis=0) if centre else x + + return x + +@dataclass +class PCAResult: + """ + The result of a call to TreeSequence.pca() capturing the output values + and algorithm convergence details. + + + """ + loadings: np.ndarray + """ + The principal component loadings. It is an orthogonal matrix. + """ + eigen_values: np.ndarray + """ + Eigenvalues of the genetic relatedness matrix. + """ + range_sketch: np.ndarray + """ + Range sketch matrix. Can be used as an input for .pca() call with range_sketch option + to further improve precision.. + """ + error_bound: np.ndarray + """ + Error bounds for the eigenvalues. + """