diff --git a/.github/workflows/build-and-test.yml b/.github/workflows/build-and-test.yml new file mode 100644 index 00000000..220172e3 --- /dev/null +++ b/.github/workflows/build-and-test.yml @@ -0,0 +1,44 @@ +name: build and test + +on: + push: + branches: [ master ] + pull_request: + branches: [ master ] + +jobs: + build: + if: "!contains(github.event.commits[0].message, '[skip ci]')" + runs-on: ubuntu-latest + strategy: + matrix: + python-version: [3.7] + + steps: + - uses: actions/checkout@v2 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -r requirements.txt + - name: Install torchdms + run: | + pip install . + - name: Test + run: | + make test + - name: Lint with flake8 + run: | + # stop the build if there are Python syntax errors or undefined names + flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics + # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide + flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics + - name: Lint with pylint + run: | + pylint **/*.py + - name: Check format with black + run: | + black --check torchdms diff --git a/README.md b/README.md index 1cb3adc8..ca41d554 100644 --- a/README.md +++ b/README.md @@ -1,25 +1,22 @@ # torchdms +![build and test](https://github.com/matsengrp/torchdms/workflows/build%20and%20test/badge.svg) [![Docker Repository on Quay](https://quay.io/repository/matsengrp/torchdms/status "Docker Repository on Quay")](https://quay.io/repository/matsengrp/torchdms) -[![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black) ## What is this? -Pytorch - Deep Mutational Scanning (`torchdms`) is a small Python package made to train neural networks on amino-acid substitution data, predicting some chosen functional score(s). +PyTorch - Deep Mutational Scanning (`torchdms`) is a Python package made to train neural networks on amino-acid substitution data, predicting some chosen functional score(s). We use the binary encoding of variants using [BinaryMap Object](https://jbloomlab.github.io/dms_variants/dms_variants.binarymap.html) as input to feed-forward networks. ## How do I install it? -To install the API and command-line scripts at the moment, it suggested you clone the repository, create a conda environment from `environment.yaml`, and run the tests to make sure everything is working properly. - git clone git@github.com:matsengrp/torchdms.git - conda env create -f environment.yaml - conda activate dms - pytest - -Install with `pip install -e .` + cd torchdms + pip install -r requirements.txt + pip install . + make test ## CLI diff --git a/environment.yml b/environment.yml deleted file mode 100644 index 619608c5..00000000 --- a/environment.yml +++ /dev/null @@ -1,22 +0,0 @@ -name: dms -channels: - - anaconda - - bioconda - - conda-forge - - defaults -dependencies: - - click - - python=3.7 - - pip - - pytorch - - scipy - - seaborn - - matplotlib - - pip: - - black - - click-config-file - - dms_variants - - docformatter - - flake8 - - pylint - - pytest diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 00000000..1e47cd34 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,11 @@ +black +click +click-config-file +dms_variants +docformatter +flake8 +matplotlib +pylint +pytest +scipy +torch==1.4.0 diff --git a/torchdms/analysis.py b/torchdms/analysis.py index d4c24060..a7461928 100644 --- a/torchdms/analysis.py +++ b/torchdms/analysis.py @@ -47,6 +47,37 @@ def __init__( ] self.val_loss_record = sys.float_info.max + def loss_of_targets_and_prediction( + self, loss_fn, targets, predictions, per_target_loss_decay + ): + """Return loss on the valid predictions, i.e. the ones that are not + NaN.""" + valid_target_indices = torch.isfinite(targets) + valid_targets = targets[valid_target_indices].to(self.device) + valid_predict = predictions[valid_target_indices].to(self.device) + return loss_fn(valid_targets, valid_predict, per_target_loss_decay) + + def complete_loss(self, loss_fn, targets, predictions, loss_decays): + """Compute our total (across targets) loss with regularization. + + Here we compute loss separately for each target, before summing + the results. This allows for us to take advantage of the samples + which may contain missing information for a subset of the + targets. + """ + per_target_loss = [ + self.loss_of_targets_and_prediction( + loss_fn, + targets[:, target_idx], + predictions[:, target_idx], + per_target_loss_decay, + ) + for target_idx, per_target_loss_decay in zip( + range(targets.shape[1]), loss_decays + ) + ] + return sum(per_target_loss) + self.model.regularization_loss() + def train( self, epoch_count, loss_fn, patience=10, min_lr=1e-5, loss_weight_span=None ): @@ -92,34 +123,6 @@ def loss_decays_of_target_extrema(extremum_pairs_across_targets): scheduler = ReduceLROnPlateau(optimizer, patience=patience, verbose=True) self.model.to(self.device) - def loss_of_targets_and_prediction(targets, predictions, per_target_loss_decay): - """Return loss on the valid predictions, i.e. the ones that are not - NaN.""" - valid_target_indices = torch.isfinite(targets) - valid_targets = targets[valid_target_indices].to(self.device) - valid_predict = predictions[valid_target_indices].to(self.device) - return loss_fn(valid_targets, valid_predict, per_target_loss_decay) - - def complete_loss(targets, predictions, loss_decays): - """Compute our total (across targets) loss with regularization. - - Here we compute loss separately for each target, before - summing the results. This allows for us to take advantage of - the samples which may contain missing information for a - subset of the targets. - """ - per_target_loss = [ - loss_of_targets_and_prediction( - targets[:, target_idx], - predictions[:, target_idx], - per_target_loss_decay, - ) - for target_idx, per_target_loss_decay in zip( - range(target_count), loss_decays - ) - ] - return sum(per_target_loss) + self.model.regularization_loss() - def step_model(): per_epoch_loss = 0.0 for _ in range(batch_count): @@ -133,8 +136,8 @@ def step_model(): samples = batch["samples"].to(self.device) predictions = self.model(samples) - loss = complete_loss( - batch["targets"], predictions, per_stratum_loss_decays + loss = self.complete_loss( + loss_fn, batch["targets"], predictions, per_stratum_loss_decays ) per_batch_loss += loss.item() @@ -153,8 +156,11 @@ def step_model(): val_samples = self.val_data.samples.to(self.device) val_predictions = self.model(val_samples) - val_loss = complete_loss( - self.val_data.targets.to(self.device), val_predictions, val_loss_decay + val_loss = self.complete_loss( + loss_fn, + self.val_data.targets.to(self.device), + val_predictions, + val_loss_decay, ).item() if val_loss < self.val_loss_record: print(f"\nvalidation loss record: {val_loss}") diff --git a/torchdms/cli.py b/torchdms/cli.py index 08115a23..6a53a52b 100644 --- a/torchdms/cli.py +++ b/torchdms/cli.py @@ -108,7 +108,7 @@ def cli(ctx, dry_run): "appended in_test column.", ) @click.option( - "--split-by", + "--partition-by", type=str, required=False, default=None, @@ -125,7 +125,7 @@ def prep( per_stratum_variants_for_test, skip_stratum_if_count_is_smaller_than, export_dataframe, - split_by, + partition_by, ): """Prepare data for training. @@ -139,57 +139,41 @@ def prep( click.echo(f"LOG: Targets: {targets}") click.echo(f"LOG: Loading substitution data for: {in_path}") aa_func_scores, wtseq = from_pickle_file(in_path) - click.echo(f"LOG: Successfully loaded data") + click.echo("LOG: Successfully loaded data") total_variants = len(aa_func_scores.iloc[:, 1]) click.echo(f"LOG: There are {total_variants} total variants in this dataset") - if split_by is None and "library" in aa_func_scores.columns: + if partition_by is None and "library" in aa_func_scores.columns: click.echo( - f"WARNING: you have a 'library' column but haven't specified a split via '--split-by'" + "WARNING: you have a 'library' column but haven't specified a partition " + "via '--partition-by'" ) - if split_by in aa_func_scores.columns: - for split_label, per_split_label_df in aa_func_scores.groupby(split_by): - click.echo(f"LOG: Partitioning data via '{split_label}'") - test_partition, val_partition, partitioned_train_data = partition( - per_split_label_df.copy(), - per_stratum_variants_for_test, - skip_stratum_if_count_is_smaller_than, - export_dataframe, - split_label, - ) - - prep_by_stratum_and_export( - test_partition, - val_partition, - partitioned_train_data, - wtseq, - targets, - out_prefix, - str(ctx.params), - split_label, - ) - - else: - test_partition, val_partition, partitioned_train_data = partition( - aa_func_scores, + def prep_by_stratum_and_export_of_partition_label_and_df(partition_label, df): + split_df = partition( + df, per_stratum_variants_for_test, skip_stratum_if_count_is_smaller_than, export_dataframe, + partition_label, ) prep_by_stratum_and_export( - test_partition, - val_partition, - partitioned_train_data, - wtseq, - targets, - out_prefix, - str(ctx.params), - None, + split_df, wtseq, targets, out_prefix, str(ctx.params), partition_label, ) + if partition_by in aa_func_scores.columns: + for partition_label, per_partition_label_df in aa_func_scores.groupby( + partition_by + ): + click.echo(f"LOG: Partitioning data via '{partition_label}'") + prep_by_stratum_and_export_of_partition_label_and_df( + partition_label, per_partition_label_df.copy() + ) + else: + prep_by_stratum_and_export_of_partition_label_and_df(None, aa_func_scores) + click.echo( "LOG: Successfully finished prep and dumped BinaryMapDataset " f"object to {out_prefix}" @@ -357,7 +341,7 @@ def evaluate(ctx, model_path, data_path, out, device): click.echo(f"LOG: loading testing data from {data_path}") data = from_pickle_file(data_path) - click.echo(f"LOG: evaluating test data with given model") + click.echo("LOG: evaluating test data with given model") evaluation = build_evaluation_dict(model, data.test, device) click.echo(f"LOG: pickle dump evalution data dictionary to {out}") @@ -417,10 +401,10 @@ def scatter(ctx, model_path, data_path, out, device): click.echo(f"LOG: loading testing data from {data_path}") data = from_pickle_file(data_path) - click.echo(f"LOG: evaluating test data with given model") + click.echo("LOG: evaluating test data with given model") evaluation = build_evaluation_dict(model, data.test, device) - click.echo(f"LOG: plotting scatter correlation") + click.echo("LOG: plotting scatter correlation") plot_test_correlation(evaluation, model, out) click.echo(f"LOG: scatter plot finished and dumped to {out}") @@ -448,7 +432,7 @@ def contour(ctx, model_path, start, end, nticks, out): if not isinstance(model, VanillaGGE): raise TypeError("Model must be a VanillaGGE") - click.echo(f"LOG: plotting contour") + click.echo("LOG: plotting contour") latent_space_contour_plot_2d(model, out, start, end, nticks) click.echo(f"LOG: Contour finished and dumped to {out}") @@ -474,7 +458,7 @@ def beta(ctx, model_path, data_path, out): f"LOG: loaded data, evaluating beta coeff for wildtype seq: {data.test.wtseq}" ) - click.echo(f"LOG: plotting beta coefficients") + click.echo("LOG: plotting beta coefficients") beta_coefficients(model, data.test, out) click.echo(f"LOG: Beta coefficients plotted and dumped to {out}") diff --git a/torchdms/data.py b/torchdms/data.py index 0093569e..70ca9c97 100644 --- a/torchdms/data.py +++ b/torchdms/data.py @@ -82,7 +82,20 @@ def target_extrema(self): return [(np.nanmin(column), np.nanmax(column)) for column in numpy_targets.T] -class SplitData: +class SplitDataframe: + """Dataframes for each of test, validation, and train. + + Train is partitioned into a list of dataframes according to the + number of mutations. + """ + + def __init__(self, *, test_data, val_data, train_data_list): + self.test = test_data + self.val = val_data + self.train = train_data_list + + +class SplitDataset: """BinaryMapDatasets for each of test, validation, and train. Train is partitioned into a list of BinaryMapDatasets according to @@ -95,6 +108,21 @@ def __init__(self, *, test_data, val_data, train_data_list, description_string): self.train = train_data_list self.description_string = description_string + @classmethod + def of_split_df(cls, split_df, wtseq, targets, description_string): + def our_of_raw(df): + return BinaryMapDataset.of_raw(df, wtseq=wtseq, targets=targets) + + return cls( + test_data=our_of_raw(split_df.test), + val_data=our_of_raw(split_df.val), + train_data_list=[ + our_of_raw(train_data_partition) + for train_data_partition in split_df.train + ], + description_string=description_string, + ) + @property def labeled_splits(self): """Returns an iterator on (label, split) pairs.""" @@ -107,18 +135,16 @@ def labeled_splits(self): def partition( aa_func_scores, - per_stratum_variants_for_test=100, - skip_stratum_if_count_is_smaller_than=300, - export_dataframe=None, - split_label=None, + per_stratum_variants_for_test, + skip_stratum_if_count_is_smaller_than, + export_dataframe, + partition_label, ): - """Partition the data into a test partition, and a list of training data - partitions. + """Partition the data as needed and build a SplitDataframe. - A "stratum" is a slice of the data with a given number of mutations. - We group training data sets into strata based on their number of - mutations so that the data is presented the neural network with an - even propotion of each. + A "stratum" is a slice of the data with a given number of mutations. We group + training data sets into strata based on their number of mutations so that the data + is presented the neural network with an even proportion of each. Furthermore, we group data rows by unique variants and then split on those grouped items so that we don't have the same variant showing up in train and test. @@ -134,7 +160,7 @@ def partition( ] aa_func_scores["in_test"] = False aa_func_scores["in_val"] = False - partitioned_train_data = [] + test_split_strata = [] for mutation_count, grouped in aa_func_scores.groupby("n_aa_substitutions"): if mutation_count == 0: @@ -166,7 +192,7 @@ def partition( assert not (aa_func_scores["in_test"] & aa_func_scores["in_val"]).any() - partitioned_train_data.append( + test_split_strata.append( aa_func_scores.loc[ (~aa_func_scores["in_test"]) & (~aa_func_scores["in_val"]) @@ -174,87 +200,47 @@ def partition( ].reset_index(drop=True) ) - test_partition = aa_func_scores.loc[aa_func_scores["in_test"],].reset_index( - drop=True - ) - val_partition = aa_func_scores.loc[aa_func_scores["in_val"],].reset_index(drop=True) + test_split = aa_func_scores.loc[aa_func_scores["in_test"],].reset_index(drop=True) + val_split = aa_func_scores.loc[aa_func_scores["in_val"],].reset_index(drop=True) if export_dataframe is not None: - if split_label is not None: - split_label_filename = make_legal_filename(split_label) + if partition_label is not None: + partition_label_filename = make_legal_filename(partition_label) to_pickle_file( - aa_func_scores, f"{export_dataframe}_{split_label_filename}.pkl" + aa_func_scores, f"{export_dataframe}_{partition_label_filename}.pkl" ) else: to_pickle_file(aa_func_scores, f"{export_dataframe}.pkl") - return test_partition, val_partition, partitioned_train_data - - -def prepare( - test_partition, - val_partition, - train_partition_list, - wtseq, - targets, - description_string, -): - """Prepare data for training by splitting into test, val, and train, - partitioning by number of substitutions, and making a SplitData object.""" - - test_data = BinaryMapDataset.of_raw(test_partition, wtseq=wtseq, targets=targets) - val_data = BinaryMapDataset.of_raw(val_partition, wtseq=wtseq, targets=targets) - train_data_list = [ - BinaryMapDataset.of_raw(train_data_partition, wtseq=wtseq, targets=targets) - for train_data_partition in train_partition_list - ] - - return SplitData( - test_data=test_data, - val_data=val_data, - train_data_list=train_data_list, - description_string=description_string, + return SplitDataframe( + test_data=test_split, val_data=val_split, train_data_list=test_split_strata, ) def prep_by_stratum_and_export( - test_partition, - val_partition, - partitioned_train_data, - wtseq, - targets, - out_prefix, - description_string, - split_label, + split_df, wtseq, targets, out_prefix, description_string, partition_label, ): """Print number of training examples per stratum and test samples, run prepare(), and export to .pkl file with descriptive filename.""" - for train_part in partitioned_train_data: + for train_part in split_df.train: num_subs = len(train_part["aa_substitutions"][0].split()) click.echo( f"LOG: There are {len(train_part)} training examples " f"for stratum: {num_subs}" ) - click.echo(f"LOG: There are {len(test_partition)} test points") - click.echo(f"LOG: Successfully partitioned data") - click.echo(f"LOG: preparing binary map dataset") + click.echo(f"LOG: There are {len(split_df.test)} test points") + click.echo("LOG: Successfully partitioned data") + click.echo("LOG: preparing binary map dataset") - if split_label is not None: - split_label_filename = make_legal_filename(split_label) - out_path = f"{out_prefix}_{split_label_filename}.pkl" + if partition_label is not None: + partition_label_filename = make_legal_filename(partition_label) + out_path = f"{out_prefix}_{partition_label_filename}.pkl" else: out_path = f"{out_prefix}.pkl" to_pickle_file( - prepare( - test_partition, - val_partition, - partitioned_train_data, - wtseq, - list(targets), - description_string, - ), + SplitDataset.of_split_df(split_df, wtseq, list(targets), description_string,), out_path, ) diff --git a/torchdms/evaluation.py b/torchdms/evaluation.py index 96a0e6bd..93e56a39 100644 --- a/torchdms/evaluation.py +++ b/torchdms/evaluation.py @@ -1,7 +1,7 @@ """Evaluating models.""" import pandas as pd -from torchdms.data import SplitData +from torchdms.data import SplitDataset from torchdms.utils import positions_in_list QUALITY_CUTOFFS = [-3.0, -1.0] @@ -71,18 +71,18 @@ def error_summary_of_error_df(error_df, model): return error_summary_df -def error_summary_of_data(data, model, split_label=None): +def error_summary_of_data(data, model, partition_label=None): error_df = error_df_of_evaluation_dict(build_evaluation_dict(model, data)) error_summary_df = error_summary_of_error_df(error_df, model) - if split_label is not None: - error_summary_df["split_label"] = split_label + if partition_label is not None: + error_summary_df["partition_label"] = partition_label return error_summary_df -def complete_error_summary(data: SplitData, model): +def complete_error_summary(data: SplitDataset, model): return pd.concat( [ - error_summary_of_data(data, model, split_label) - for split_label, data in data.labeled_splits + error_summary_of_data(data, model, partition_label) + for partition_label, data in data.labeled_splits ] ) diff --git a/torchdms/model.py b/torchdms/model.py index 37579655..cefd1490 100644 --- a/torchdms/model.py +++ b/torchdms/model.py @@ -51,7 +51,7 @@ def __init__( assert len(layer_sizes) == len(activations) - layer_name = f"input_layer" + layer_name = "input_layer" # additive model if len(layer_sizes) == 0: @@ -70,7 +70,7 @@ def __init__( bias = True # final layer - layer_name = f"output_layer" + layer_name = "output_layer" self.layers.append(layer_name) setattr(self, layer_name, nn.Linear(layer_sizes[-1], output_size)) @@ -232,7 +232,7 @@ def model_of_string(model_string, data_path, monotonic_sign): test_dataset = data.test if model_name == "VanillaGGE": if len(layers) == 0: - click.echo(f"LOG: No layers provided, so I'm creating a linear model.") + click.echo("LOG: No layers provided, so I'm creating a linear model.") for layer in layers: if not isinstance(layer, int): raise TypeError("All layer input must be integers") diff --git a/torchdms/plot.py b/torchdms/plot.py index 6fb4932c..ad39fc23 100644 --- a/torchdms/plot.py +++ b/torchdms/plot.py @@ -17,7 +17,6 @@ theme_set, ) import scipy.stats as stats -import torch def plot_error(error_df, out_path, show_points=False): @@ -53,8 +52,8 @@ def plot_test_correlation(evaluation_dict, model, out, cmap="plasma"): targ = evaluation_dict["targets"][:, target] corr = stats.pearsonr(pred, targ) scatter = ax[target].scatter(pred, targ, cmap=cmap, c=n_aa_substitutions, s=8.0) - ax[target].set_xlabel(f"Predicted") - ax[target].set_ylabel(f"Observed") + ax[target].set_xlabel("Predicted") + ax[target].set_ylabel("Observed") target_name = evaluation_dict["target_names"][target] plot_title = f"Test Data for {target_name}\npearsonr = {round(corr[0],3)}" ax[target].set_title(plot_title) @@ -94,38 +93,6 @@ def latent_space_contour_plot_2d(model, out, start=0, end=1000, nticks=100): "it like https://github.com/matsengrp/torchdms/issues/26" ) - num_targets = model.output_size - prediction_matrices = [np.empty([nticks, nticks]) for _ in range(num_targets)] - for i, latent1_value in enumerate(np.linspace(start, end, nticks)): - for j, latent2_value in enumerate(np.linspace(start, end, nticks)): - lat_sample = torch.from_numpy( - np.array([latent1_value, latent2_value]) - ).float() - predictions = model.from_latent(lat_sample) - for pred_idx in range( # pylint: disable=consider-using-enumerate - len(predictions) - ): - prediction_matrices[pred_idx][i][j] = predictions[pred_idx] - - width = 7 * num_targets - fig, ax = plt.subplots(1, num_targets, figsize=(width, 6)) - # Make ax a list even if there's only one target. - if num_targets == 1: - ax = [ax] - for idx, matrix in enumerate(prediction_matrices): - mapp = ax[idx].imshow(matrix) - - # TODO We should have the ticks which show the range of inputs - # matplotlib does not make this obvious. - # ax[idx].set_xticks(ticks=np.linspace(start,end,nticks)) - # ax[idx].set_yticks(np.linspace(start,end,nticks)) - ax[idx].set_xlabel("latent space dimension 1") - ax[idx].set_ylabel("latent space dimension 2") - ax[idx].set_title(f"Prediction Node {idx}\nrange {start} to {end}") - fig.colorbar(mapp, ax=ax[idx], shrink=0.5) - fig.tight_layout() - fig.savefig(out) - def beta_coefficients(model, test_data, out): """This function takes in a (ideally trained) model and plots the values of @@ -139,24 +106,24 @@ def beta_coefficients(model, test_data, out): # below gives us the first transformation matrix of the model # going from inputs -> latent space, thus # a tensor of shape (n latent space dims, n input nodes) - beta_coefficients = next(model.parameters()).data + beta_coefficient_data = next(model.parameters()).data bmap = dms.binarymap.BinaryMap(test_data.original_df,) # To represent the wtseq in the heatmap, create a mask # to encode which matrix entries are the wt nt in each position. wtmask = np.full([len(bmap.alphabet), len(test_data.wtseq)], False, dtype=bool) alphabet = bmap.alphabet - for column_position, nt in enumerate(test_data.wtseq): - row_position = alphabet.index(nt) + for column_position, aa in enumerate(test_data.wtseq): + row_position = alphabet.index(aa) wtmask[row_position, column_position] = True # plot beta's - num_latent_dims = beta_coefficients.shape[0] + num_latent_dims = beta_coefficient_data.shape[0] fig, ax = plt.subplots(num_latent_dims, figsize=(10, 5 * num_latent_dims)) if num_latent_dims == 1: ax = [ax] for latent_dim in range(num_latent_dims): - latent = beta_coefficients[latent_dim].numpy() + latent = beta_coefficient_data[latent_dim].numpy() beta_map = latent.reshape(len(bmap.alphabet), len(test_data.wtseq)) beta_map[wtmask] = np.nan mapp = ax[latent_dim].imshow(beta_map, aspect="auto") diff --git a/torchdms/test/test_data.py b/torchdms/test/test_data.py index 3e715a22..b3900070 100644 --- a/torchdms/test/test_data.py +++ b/torchdms/test/test_data.py @@ -19,11 +19,17 @@ def test_partition_is_clean(): data, _ = from_pickle_file(TEST_DATA_PATH) for seed in range(50): random.seed(seed) - (test, val, partitioned_train) = partition( + split_df = partition( data, per_stratum_variants_for_test=10, skip_stratum_if_count_is_smaller_than=30, + export_dataframe=None, + partition_label=None, + ) + train = pd.concat(split_df.train) + assert set(split_df.test["aa_substitutions"]).isdisjoint( + set(train["aa_substitutions"]) + ) + assert set(split_df.test["aa_substitutions"]).isdisjoint( + set(split_df.val["aa_substitutions"]) ) - train = pd.concat(partitioned_train) - assert set(test["aa_substitutions"]).isdisjoint(set(train["aa_substitutions"])) - assert set(test["aa_substitutions"]).isdisjoint(set(val["aa_substitutions"]))