From 2bb7d266c94254e77bcfb8351759fc97a050c664 Mon Sep 17 00:00:00 2001 From: jgallowa07 Date: Thu, 23 May 2024 11:00:29 -0700 Subject: [PATCH] removed the box constrait --- multidms/biophysical.py | 17 ++++++++++------- multidms/data.py | 2 +- multidms/model.py | 15 +++++++-------- multidms/model_collection.py | 1 + multidms/utils.py | 5 +++-- tests/test_data.py | 35 +++++++++++++++++++++++++++++++++-- 6 files changed, 55 insertions(+), 20 deletions(-) diff --git a/multidms/biophysical.py b/multidms/biophysical.py index 56c416f..e43a34f 100644 --- a/multidms/biophysical.py +++ b/multidms/biophysical.py @@ -41,7 +41,8 @@ import jax # jax.config.update("jax_enable_x64", True) - +# TODO, each of these should be a class that decends from a base model_component class +# and should have a method that returns the function with the parameters as arguments r""" +++++++++++++++++++++++++++++ @@ -337,9 +338,10 @@ def proximal_box_constraints(params, hyperparameters, *args, **kwargs): params = transform(params, bundle_idxs) # clamp theta scale to monotonic, and with optional upper bound - params["theta"]["ge_scale"] = params["theta"]["ge_scale"].clip( - 0, ge_scale_upper_bound - ) + if "ge_scale" in params["theta"]: + params["theta"]["ge_scale"] = params["theta"]["ge_scale"].clip( + 0, ge_scale_upper_bound + ) # Any params to constrain during fit # clamp beta0 for reference condition in non-scaled parameterization # (where it's a box constraint) @@ -393,9 +395,10 @@ def proximal_objective(Dop, params, hyperparameters, scaling=1.0): params = transform(params, bundle_idxs) # clamp theta scale to monotonic, and with optional upper bound - params["theta"]["ge_scale"] = params["theta"]["ge_scale"].clip( - 0, ge_scale_upper_bound - ) + if "ge_scale" in params["theta"]: + params["theta"]["ge_scale"] = params["theta"]["ge_scale"].clip( + 0, ge_scale_upper_bound + ) # Any params to constrain during fit if lock_params is not None: for (param, subparam), value in lock_params.items(): diff --git a/multidms/data.py b/multidms/data.py index 5ae4679..45b225c 100644 --- a/multidms/data.py +++ b/multidms/data.py @@ -32,6 +32,7 @@ jax.config.update("jax_enable_x64", True) +# TODO why use these when we have the mut parser object? DEPRECATE probably def split_sub(sub_string): """String match the wt, site, and sub aa in a given string denoting a single substitution @@ -525,7 +526,6 @@ def get_nis_from_site_map(site_map): ) for condition in self._conditions: - # compute times seen in data # compute the sum of each mutation (column) in the scaled data times_seen = pd.Series( diff --git a/multidms/model.py b/multidms/model.py index 6921995..92a683d 100644 --- a/multidms/model.py +++ b/multidms/model.py @@ -866,9 +866,9 @@ def add_phenotypes_to_df( if phenotype_as_effect: latent_predictions -= wildtype_df.loc[condition, "predicted_latent"] latent_predictions[nan_variant_indices] = onp.nan - ret.loc[condition_df.index.values, latent_phenotype_col] = ( - latent_predictions - ) + ret.loc[ + condition_df.index.values, latent_phenotype_col + ] = latent_predictions # func_score predictions on binary variants, X phenotype_predictions = onp.array( @@ -880,9 +880,9 @@ def add_phenotypes_to_df( condition, "predicted_func_score" ] phenotype_predictions[nan_variant_indices] = onp.nan - ret.loc[condition_df.index.values, observed_phenotype_col] = ( - phenotype_predictions - ) + ret.loc[ + condition_df.index.values, observed_phenotype_col + ] = phenotype_predictions return ret @@ -1085,14 +1085,13 @@ def fit( upper_bound_ge_scale = 2 * y_range # box constraints for the reference beta0 parameter. - lock_params[("beta0", self.data.reference)] = 0.0 + # lock_params[("beta0", self.data.reference)] = 0.0 compiled_proximal = self._model_components["proximal"] compiled_objective = jax.jit(self._model_components["objective"]) # if we have more than one condition, we need to set up the ADMM optimization if len(self.data.conditions) > 1: - non_identical_signs = { condition: jnp.where(self.data._non_identical_idxs[condition], -1, 1) for condition in self.data.conditions diff --git a/multidms/model_collection.py b/multidms/model_collection.py index 625084f..186ba76 100644 --- a/multidms/model_collection.py +++ b/multidms/model_collection.py @@ -191,6 +191,7 @@ def stack_fit_models(fit_models_list): return pd.concat([f.to_frame().T for f in fit_models_list], ignore_index=True) +# TODO make it easier to debug failed fits def fit_models(params, n_threads=-1, failures="error"): """Fit collection of :class:`multidms.model.Model` models. diff --git a/multidms/utils.py b/multidms/utils.py index 64cfdd5..2b33c62 100644 --- a/multidms/utils.py +++ b/multidms/utils.py @@ -11,8 +11,9 @@ import jax.numpy as jnp -# TODO add to -# scale_coeff_lasso_shift = lambda x: x.scale_coeff_lasso_shift.apply(lambda x: "{:.2e}".format(x)) +# TODO add +# scale_coeff_lasso_shift = lambda x: +# x.scale_coeff_lasso_shift.apply(lambda x: "{:.2e}".format(x)) def difference_matrix(n, ref_index=0): diff --git a/tests/test_data.py b/tests/test_data.py index 8ec6dfb..36f6145 100644 --- a/tests/test_data.py +++ b/tests/test_data.py @@ -255,7 +255,38 @@ def test_difference_matrix(): """ -def test_single_condition_fit(): +def test_linear_model_fit_simple(): + """ + Simple test to see that the linear model + fits without error. + """ + data = multidms.Data( + TEST_FUNC_SCORES.query("condition == 'a'"), + alphabet=multidms.AAS_WITHSTOP, + reference="a", + assert_site_integrity=False, + ) + model = multidms.Model(data, multidms.biophysical.identity_activation, PRNGKey=23) + model.fit(maxiter=2, warn_unconverged=False) + + +def test_linear_model_multi_cond_fit_simple(): + """ + Simple test to see that the linear model + fits multiple conditions without error. + """ + data = multidms.Data( + TEST_FUNC_SCORES, + alphabet=multidms.AAS_WITHSTOP, + reference="a", + assert_site_integrity=False, + ) + model = multidms.Model(data, multidms.biophysical.identity_activation, PRNGKey=23) + + model.fit(maxiter=2, warn_unconverged=False) + + +def test_fit_simple(): """ Simple test to see that the single-condition model fits without error. @@ -271,7 +302,7 @@ def test_single_condition_fit(): model.fit(maxiter=2, warn_unconverged=False) -def test_fit_simple(): +def test_multi_cond_fit_simple(): """ Simple test to make sure the multi-condition model fits without error.