Skip to content

Commit

Permalink
removed the box constraint
Browse files Browse the repository at this point in the history
  • Loading branch information
jgallowa07 committed May 23, 2024
1 parent 2c7a035 commit 2bb7d26
Show file tree
Hide file tree
Showing 6 changed files with 55 additions and 20 deletions.
17 changes: 10 additions & 7 deletions multidms/biophysical.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,8 @@
import jax

# jax.config.update("jax_enable_x64", True)

# TODO, each of these should be a class that descends from a base model_component class
# and should have a method that returns the function with the parameters as arguments

r"""
+++++++++++++++++++++++++++++
Expand Down Expand Up @@ -337,9 +338,10 @@ def proximal_box_constraints(params, hyperparameters, *args, **kwargs):
params = transform(params, bundle_idxs)

# clamp theta scale to monotonic, and with optional upper bound
params["theta"]["ge_scale"] = params["theta"]["ge_scale"].clip(
0, ge_scale_upper_bound
)
if "ge_scale" in params["theta"]:
params["theta"]["ge_scale"] = params["theta"]["ge_scale"].clip(
0, ge_scale_upper_bound
)
# Any params to constrain during fit
# clamp beta0 for reference condition in non-scaled parameterization
# (where it's a box constraint)
Expand Down Expand Up @@ -393,9 +395,10 @@ def proximal_objective(Dop, params, hyperparameters, scaling=1.0):
params = transform(params, bundle_idxs)

# clamp theta scale to monotonic, and with optional upper bound
params["theta"]["ge_scale"] = params["theta"]["ge_scale"].clip(
0, ge_scale_upper_bound
)
if "ge_scale" in params["theta"]:
params["theta"]["ge_scale"] = params["theta"]["ge_scale"].clip(
0, ge_scale_upper_bound
)
# Any params to constrain during fit
if lock_params is not None:
for (param, subparam), value in lock_params.items():
Expand Down
2 changes: 1 addition & 1 deletion multidms/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
jax.config.update("jax_enable_x64", True)


# TODO why use these when we have the mut parser object? DEPRECATE probably
def split_sub(sub_string):
"""String match the wt, site, and sub aa
in a given string denoting a single substitution
Expand Down Expand Up @@ -525,7 +526,6 @@ def get_nis_from_site_map(site_map):
)

for condition in self._conditions:

# compute times seen in data
# compute the sum of each mutation (column) in the scaled data
times_seen = pd.Series(
Expand Down
15 changes: 7 additions & 8 deletions multidms/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -866,9 +866,9 @@ def add_phenotypes_to_df(
if phenotype_as_effect:
latent_predictions -= wildtype_df.loc[condition, "predicted_latent"]
latent_predictions[nan_variant_indices] = onp.nan
ret.loc[condition_df.index.values, latent_phenotype_col] = (
latent_predictions
)
ret.loc[
condition_df.index.values, latent_phenotype_col
] = latent_predictions

# func_score predictions on binary variants, X
phenotype_predictions = onp.array(
Expand All @@ -880,9 +880,9 @@ def add_phenotypes_to_df(
condition, "predicted_func_score"
]
phenotype_predictions[nan_variant_indices] = onp.nan
ret.loc[condition_df.index.values, observed_phenotype_col] = (
phenotype_predictions
)
ret.loc[
condition_df.index.values, observed_phenotype_col
] = phenotype_predictions

return ret

Expand Down Expand Up @@ -1085,14 +1085,13 @@ def fit(
upper_bound_ge_scale = 2 * y_range

# box constraints for the reference beta0 parameter.
lock_params[("beta0", self.data.reference)] = 0.0
# lock_params[("beta0", self.data.reference)] = 0.0

compiled_proximal = self._model_components["proximal"]
compiled_objective = jax.jit(self._model_components["objective"])

# if we have more than one condition, we need to set up the ADMM optimization
if len(self.data.conditions) > 1:

non_identical_signs = {
condition: jnp.where(self.data._non_identical_idxs[condition], -1, 1)
for condition in self.data.conditions
Expand Down
1 change: 1 addition & 0 deletions multidms/model_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,7 @@ def stack_fit_models(fit_models_list):
return pd.concat([f.to_frame().T for f in fit_models_list], ignore_index=True)


# TODO make it easier to debug failed fits
def fit_models(params, n_threads=-1, failures="error"):
"""Fit collection of :class:`multidms.model.Model` models.
Expand Down
5 changes: 3 additions & 2 deletions multidms/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,9 @@
import jax.numpy as jnp


# TODO add to
# scale_coeff_lasso_shift = lambda x: x.scale_coeff_lasso_shift.apply(lambda x: "{:.2e}".format(x))
# TODO add
# scale_coeff_lasso_shift = lambda x:
# x.scale_coeff_lasso_shift.apply(lambda x: "{:.2e}".format(x))


def difference_matrix(n, ref_index=0):
Expand Down
35 changes: 33 additions & 2 deletions tests/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,38 @@ def test_difference_matrix():
"""


def test_single_condition_fit():
def test_linear_model_fit_simple():
    """
    Smoke test: a linear model (identity activation) built on the
    single condition 'a' fits without raising an error.
    """
    # Restrict the shared fixture to one condition so only the
    # single-condition fitting path is exercised.
    single_condition_scores = TEST_FUNC_SCORES.query("condition == 'a'")
    dataset = multidms.Data(
        single_condition_scores,
        alphabet=multidms.AAS_WITHSTOP,
        reference="a",
        assert_site_integrity=False,
    )
    linear_model = multidms.Model(
        dataset,
        multidms.biophysical.identity_activation,
        PRNGKey=23,
    )
    # Two iterations are enough to check the fit runs; convergence
    # warnings are silenced because convergence is not expected here.
    linear_model.fit(maxiter=2, warn_unconverged=False)


def test_linear_model_multi_cond_fit_simple():
    """
    Smoke test: a linear model (identity activation) built on the
    full multi-condition fixture fits without raising an error.
    """
    data_kwargs = {
        "alphabet": multidms.AAS_WITHSTOP,
        "reference": "a",
        "assert_site_integrity": False,
    }
    # The unfiltered fixture is passed so every condition it contains
    # takes part in the fit.
    dataset = multidms.Data(TEST_FUNC_SCORES, **data_kwargs)
    linear_model = multidms.Model(
        dataset,
        multidms.biophysical.identity_activation,
        PRNGKey=23,
    )
    # Two iterations are enough to check the fit runs; convergence
    # warnings are silenced because convergence is not expected here.
    linear_model.fit(maxiter=2, warn_unconverged=False)


def test_fit_simple():
"""
Simple test to see that the single-condition model
fits without error.
Expand All @@ -271,7 +302,7 @@ def test_single_condition_fit():
model.fit(maxiter=2, warn_unconverged=False)


def test_fit_simple():
def test_multi_cond_fit_simple():
"""
Simple test to make sure the multi-condition model
fits without error.
Expand Down

0 comments on commit 2bb7d26

Please sign in to comment.