Skip to content

Commit

Permalink
Add .test for Normal distribution (#1)
Browse files Browse the repository at this point in the history
* initial version

* debug score test and add docstring

* simplify null model fitting

* Ran Black formatter

* store variance instead of std
  • Loading branch information
jykr authored Dec 20, 2024
1 parent 25f96b6 commit 3703cef
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 7 deletions.
41 changes: 41 additions & 0 deletions SpatialDE/_internal/score_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,3 +273,44 @@ def _grad_negative_negbinom_loglik(params, counts, sizefactors):
+ (counts - mus) / one_alpha_mu
) # d/d_alpha
return -tf.convert_to_tensor((grad0, grad1))


class NormalScoreTest(ScoreTest):
    """Score test with a Gaussian observation model and identity link.

    The null model is a constant mean ``mu`` with a single noise variance
    ``sigmasq`` shared by all observations. The test statistic and its moments
    are computed from the kernel(s) stored in ``self._K``.
    """

    @dataclass
    class NullModel(ScoreTest.NullModel):
        # Null-model parameters as produced by ``_fit_null``.
        mu: tf.Tensor  # mean of the observations
        sigmasq: tf.Tensor  # variance of the observations (not the standard deviation)

    def _fit_null(self, y: tf.Tensor) -> NullModel:
        """Fit the null model: mean and (population) variance over all of ``y``.

        Args:
            y: Observations for one feature.

        Returns:
            A ``NullModel`` holding scalar ``mu`` and ``sigmasq`` tensors.
        """
        # The variance reduction is exported as ``tf.math.reduce_variance``;
        # there is no top-level ``tf.reduce_variance`` alias in TensorFlow 2.
        return self.NullModel(tf.reduce_mean(y), tf.math.reduce_variance(y))

    def _test(
        self, y: tf.Tensor, nullmodel: NullModel
    ) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
        """Run the score test for ``y`` against ``self._K`` given a fitted null model.

        All inputs are cast to the default float type before being handed to the
        compiled ``_do_test``.

        Returns:
            The ``(stat, e_tilde, I_tau_tau_tilde)`` triple from ``_do_test``.
        """
        return self._do_test(
            self._K,
            to_default_float(y),
            to_default_float(nullmodel.sigmasq),
            to_default_float(nullmodel.mu),
        )

    @staticmethod
    @tf.function(experimental_compile=True)
    def _do_test(
        K: tf.Tensor, rawy: tf.Tensor, sigmasq: tf.Tensor, mu: tf.Tensor
    ) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
        """Compute the score statistic and its moments under the null model.

        Args:
            K: Kernel matrix; presumably may also be a stack of kernel matrices
                with the observations on the last two axes — TODO confirm
                against the caller.
            rawy: Observations, one entry per row/column of ``K``.
            sigmasq: Null-model variance. A scalar (as produced by ``_fit_null``)
                or a per-observation vector; it is broadcast against ``rawy``.
            mu: Null-model mean.

        Returns:
            ``(stat, e_tilde, I_tau_tau_tilde)``: the score statistic, half the
            trace of ``P @ K``, and ``I_tau_tau`` corrected for the nuisance
            parameter via ``I_tau_theta**2 / I_theta_theta``.

        Note:
            A constant ``rawy`` yields ``sigmasq == 0`` and therefore non-finite
            results; this is not guarded against here.
        """
        resid = rawy - mu

        # Per-observation precision (inverse variance). ``sigmasq`` is a scalar
        # when it comes from ``_fit_null``, but W is sliced as a vector and fed
        # to ``tf.linalg.diag`` below, so broadcast it to the shape of ``rawy``.
        W = tf.ones_like(rawy) / sigmasq

        stat = 0.5 * tf.reduce_sum(
            resid * W * tf.tensordot(K, W * resid, axes=(-1, -1)), axis=-1
        )

        w_total = tf.reduce_sum(W)
        # P = diag(W) - W W^T / sum(W); symmetric by construction.
        P = tf.linalg.diag(W) - W[:, tf.newaxis] * W[tf.newaxis, :] / w_total
        # PK equals P @ K, formed without materialising the product with diag(W).
        PK = W[:, tf.newaxis] * K - W[:, tf.newaxis] * ((W[tf.newaxis, :] @ K) / w_total)
        trace_PK = tf.linalg.trace(PK)
        e_tilde = 0.5 * trace_PK
        I_tau_tau = 0.5 * tf.reduce_sum(PK * PK, axis=(-2, -1))
        # P is symmetric, so the elementwise sum equals trace(PK @ P).
        I_tau_theta = 0.5 * tf.reduce_sum(PK * P, axis=(-2, -1))
        I_theta_theta = 0.5 * tf.reduce_sum(tf.square(P), axis=(-2, -1))
        I_tau_tau_tilde = I_tau_tau - tf.square(I_tau_theta) / I_theta_theta

        return stat, e_tilde, I_tau_tau_tilde
19 changes: 12 additions & 7 deletions SpatialDE/de_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from time import time
import warnings
from itertools import zip_longest
from typing import Optional, Dict, Tuple, Union, List
from typing import Optional, Dict, Tuple, Union, List, Literal

import numpy as np
import pandas as pd
Expand All @@ -16,6 +16,7 @@
from ._internal.util import bh_adjust, calc_sizefactors, default_kernel_space, kspace_walk
from ._internal.score_test import (
NegativeBinomialScoreTest,
NormalScoreTest,
combine_pvalues,
)
from ._internal.tf_dataset import AnnDataDataset
Expand Down Expand Up @@ -63,6 +64,7 @@ def test(
kernel_space: Optional[Dict[str, Union[float, List[float]]]] = None,
sizefactors: Optional[np.ndarray] = None,
stack_kernels: Optional[bool] = None,
obs_dist: Literal["NegativeBinomial", "Normal"] = "NegativeBinomial",
use_cache: bool = True,
) -> Tuple[pd.DataFrame, Union[pd.DataFrame, None]]:
"""
Expand Down Expand Up @@ -94,6 +96,7 @@ def test(
the kernel matrices. This leads to increased memory consumption, but will drastically improve runtime
on GPUs for smaller data sets. Defaults to ``True`` for datasets with less than 2000 observations and
``False`` otherwise.
obs_dist: Distribution of the observations, either ``"NegativeBinomial"`` (default) or ``"Normal"``. If set to ``"Normal"``, the observations are modeled with i.i.d. Gaussian noise and an identity link function.
use_cache: Whether to use a pre-computed distance matrix for all kernels instead of computing the distance
matrix anew for each kernel. Increases memory consumption, but is somewhat faster.
Expand All @@ -111,19 +114,21 @@ def test(
sizefactors = calc_sizefactors(adata)
if kernel_space is None:
kernel_space = default_kernel_space(dcache)

individual_results = None if omnibus else []
if stack_kernels is None and adata.n_obs <= 2000 or stack_kernels or omnibus:
kernels = []
kernelnames = []
for k, name in kspace_walk(kernel_space, dcache):
kernels.append(k)
kernelnames.append(name)
test = NegativeBinomialScoreTest(
sizefactors,
omnibus,
kernels,
)
if obs_dist == "NegativeBinomial":
test = NegativeBinomialScoreTest(
sizefactors,
omnibus,
kernels,
)
else:
test = NormalScoreTest(omnibus, kernels)

results = []
with tqdm(total=adata.n_vars) as pbar:
Expand Down

0 comments on commit 3703cef

Please sign in to comment.