New simulations (#134)
* correlations analysis, validation loss working

* 10 default shifted sites

* Final polish

* Formatting, tinker with groupby.apply 2.2.0 depr.

* new float32 doctest updates

* added new docs link
jgallowa07 authored Feb 21, 2024
1 parent 9d7456d commit 37e55f4
Showing 10 changed files with 9,665 additions and 3,468 deletions.
1 change: 1 addition & 0 deletions docs/index.rst
@@ -41,6 +41,7 @@ and how much the effects differ between experiments.

installation
biophysical_model
+simulation_validation
fit_delta_BA1_example
multidms
acknowledgments
3 changes: 3 additions & 0 deletions docs/simulation_validation.nblink
@@ -0,0 +1,3 @@
{
"path": "../notebooks/simulation_validation.ipynb"
}
2 changes: 1 addition & 1 deletion multidms/biophysical.py
@@ -38,7 +38,7 @@
from jaxopt.prox import prox_lasso
import jax

-jax.config.update("jax_enable_x64", True)
+# jax.config.update("jax_enable_x64", True)


r"""
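The removal above (and the matching deletion in multidms/data.py below) leaves JAX at its default single precision, which is why the doctest outputs later in this diff now report dtype=float32. A minimal sketch of what the flag controls, assuming a stock JAX install (illustration only, not part of this commit):

# Sketch: effect of the "jax_enable_x64" flag on default array dtypes.
import jax
import jax.numpy as jnp

print(jnp.ones(3).dtype)  # float32 -- JAX's default single precision

jax.config.update("jax_enable_x64", True)  # opt back in to double precision
print(jnp.ones(3).dtype)  # float64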
3 changes: 0 additions & 3 deletions multidms/data.py
@@ -21,15 +21,12 @@

from multidms import AAS

-import jax
import jax.numpy as jnp
import seaborn as sns
from jax.experimental import sparse
from matplotlib import pyplot as plt
from pandarallel import pandarallel

-jax.config.update("jax_enable_x64", True)
-

def split_sub(sub_string):
"""String match the wt, site, and sub aa
138 changes: 107 additions & 31 deletions multidms/model.py
@@ -140,10 +140,10 @@ class Model:
>>> model.get_mutations_df() # doctest: +NORMALIZE_WHITESPACE
beta shift_b predicted_func_score_a predicted_func_score_b \
mutation
-M1E 0.080868 0.0 0.101030 0.565154
-M1W -0.386247 0.0 -0.476895 -0.012770
-G3P -0.375656 0.0 -0.464124 0.000000
-G3R 1.668974 0.0 1.707195 2.171319
+M1E 1.816086 0.0 1.800479 1.379661
+M1W -0.754885 0.0 -0.901211 -1.322029
+G3P 0.339889 0.0 0.420818 0.000000
+G3R -0.534835 0.0 -0.653051 -1.073869
<BLANKLINE>
times_seen_a times_seen_b wts sites muts
mutation
@@ -160,29 +160,26 @@ class Model:
>>> model.get_variants_df() # doctest: +NORMALIZE_WHITESPACE
condition aa_substitutions func_score var_wrt_ref predicted_latent \
-0 a M1E 2.0 M1E 0.080868
-1 a G3R -7.0 G3R 1.668974
-2 a G3P -0.5 G3P -0.375656
-3 a M1W 2.3 M1W -0.386247
-4 b M1E 1.0 G3P M1E 0.080868
-5 b P3R -5.0 G3R 2.044630
-6 b P3G 0.4 0.375656
-7 b M1E P3G 2.7 M1E 0.456523
-8 b M1E P3R -2.7 G3R M1E 2.125498
+0 a M1E 2.0 M1E 1.816086
+1 a G3R -7.0 G3R -0.534835
+2 a G3P -0.5 G3P 0.339889
+3 a M1W 2.3 M1W -0.754885
+4 b M1E 1.0 G3P M1E 1.816086
+5 b P3R -5.0 G3R -0.874724
+6 b P3G 0.4 -0.339889
+7 b M1E P3G 2.7 M1E 1.476197
+8 b M1E P3R -2.7 G3R M1E 0.941362
<BLANKLINE>
predicted_func_score
-0 0.101030
-1 1.707195
-2 -0.464124
-3 -0.476895
-4 0.098285
-5 2.171319
-6 0.464124
-7 0.565154
-8 2.223789
+0 1.800479
+1 -0.653051
+2 0.420818
+3 -0.901211
+4 1.560311
+5 -1.073869
+6 -0.420818
+7 1.379661
+8 0.992495
We now have access to the predicted (and gamma-corrected) functional scores
given the model's current parameters.
Expand All @@ -192,13 +189,13 @@ class Model:
given our initialized parameters
>>> model.loss
-Array(7.19312981, dtype=float64)
+Array(4.7124467, dtype=float32)
Next, we fit the model with some chosen hyperparameters.
>>> model.fit(maxiter=1000, lasso_shift=1e-5)
>>> model.loss
-Array(1.18200934e-05, dtype=float64)
+Array(6.0517805e-06, dtype=float32)
The model tunes its parameters in place, and the subsequent call to retrieve
the loss reflects the model's loss given its updated parameters.
@@ -333,7 +330,7 @@ def __repr__(self):
"""Returns a string representation of the object."""
return f"{self.__class__.__name__}({self.name})"

-def _str__(self):
+def __str__(self):
"""Returns a string representation of the object."""
return f"{self.__class__.__name__}({self.name})"

@@ -373,6 +370,7 @@ def loss(self) -> float:
"scale_coeff_ridge_beta": 0.0,
"scale_coeff_ridge_shift": 0.0,
"scale_coeff_ridge_gamma": 0.0,
"scale_ridge_alpha_d": 0.0,
}
data = (self.data.training_data["X"], self.data.training_data["y"])
return jax.jit(self.model_components["objective"])(self.params, data, **kwargs)
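The "scale_ridge_alpha_d" entry added above is one more penalty-scale hyperparameter forwarded to the objective; at 0.0 (the default shown here) that ridge term contributes nothing. A generic, hedged sketch of how such a scale coefficient enters a ridge penalty (illustration only; the actual multidms objective is defined in multidms.biophysical, and the name alpha_d_params below is hypothetical):

# Generic ridge penalty weighted by a scale coefficient (illustration only).
import jax.numpy as jnp

def ridge_penalty(params, scale):
    """Return scale * sum of squared parameters; scale=0.0 disables the term."""
    return scale * jnp.sum(jnp.square(params))

alpha_d_params = jnp.array([0.5, -0.2])  # hypothetical parameter block
print(ridge_penalty(alpha_d_params, scale=0.0))  # 0.0 -- penalty switched off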
@@ -548,6 +546,82 @@ def get_mutations_df(

return mutations_df[col_order]

def get_df_loss(self, df, error_if_unknown=False, verbose=False):
"""
Get the loss of the model on a given data frame.
Parameters
----------
df : pandas.DataFrame
Data frame containing variants. Requirements are the same as
those used to initialize the `multidms.Data` object - except
the indices must be unique.
error_if_unknown : bool
If some of the substitutions in a variant are not present in
the model (not in :attr:`AbstractEpistasis.binarymap`)
then by default we do not include those variants
in the loss calculation. If `True`, raise an error.
verbose : bool
If True, print the number of valid and invalid variants.
Returns
-------
float
The loss of the model on the given data frame.
"""
substitutions_col = "aa_substitutions"
condition_col = "condition"
func_score_col = "func_score"
ref_bmap = self.data.binarymaps[self.data.reference]

if substitutions_col not in df.columns:
raise ValueError("`df` lacks `substitutions_col` " f"{substitutions_col}")
if condition_col not in df.columns:
raise ValueError("`df` lacks `condition_col` " f"{condition_col}")

X, y = {}, {}
for condition, condition_df in df.groupby(condition_col):
variant_subs = condition_df[substitutions_col]
if condition not in self.data.reference_sequence_conditions:
variant_subs = condition_df.apply(
lambda x: self.data.convert_subs_wrt_ref_seq(
condition, x[substitutions_col]
),
axis=1,
)

# build binary variants as a csr matrix and collect their functional-score targets
valid, invalid = 0, 0  # counts of variants that could / could not be encoded
binary_variants = []
variant_targets = []

for subs, target in zip(variant_subs, condition_df[func_score_col]):
try:
binary_variants.append(ref_bmap.sub_str_to_binary(subs))
variant_targets.append(target)
valid += 1
except ValueError:
if error_if_unknown:
raise ValueError(
"Variant has substitutions not in model:"
f"\n{subs}\nMaybe use `unknown_as_nan`?"
)
else:
invalid += 1

if verbose:
print(
f"condition: {condition}, n valid variants: "
f"{valid}, n invalid variants: {invalid}"
)

X[condition] = sparse.BCOO.from_scipy_sparse(
scipy.sparse.csr_matrix(onp.vstack(binary_variants))
)
y[condition] = jnp.array(variant_targets)

return self.model_components["objective"](self.params, (X, y))
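
A hedged usage sketch for the get_df_loss method added above: `model` is a fitted multidms.Model and `held_out_df` is a hypothetical data frame of held-out variants with the required "aa_substitutions", "condition", and "func_score" columns and unique indices (names are illustrative, not from this diff).

# Compute the model's loss on held-out variants; verbose=True prints
# per-condition counts of variants that could / could not be encoded.
validation_loss = model.get_df_loss(held_out_df, error_if_unknown=False, verbose=True)
print(f"validation loss: {validation_loss}")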

def add_phenotypes_to_df(
self,
df,
@@ -1149,9 +1223,11 @@ def plot_shifts_by_site(
**kwargs,
)
color = [
-self.data.condition_colors[condition]
-if s not in self.data.non_identical_sites[condition]
-else (0.0, 0.0, 0.0)
+(
+self.data.condition_colors[condition]
+if s not in self.data.non_identical_sites[condition]
+else (0.0, 0.0, 0.0)
+)
for s in mutation_site_summary_df.sites
]
size = [