diff --git a/uncoverml/validate.py b/uncoverml/validate.py index 4fbffd42..9708b7d3 100644 --- a/uncoverml/validate.py +++ b/uncoverml/validate.py @@ -647,7 +647,8 @@ def plot_feature_correlation_matrix(config: Config, x_all, oos_val): def validation_scatter(config: Config, y_true, predictions, oos_val): - scores_file = os.path.join(config.output_dir, config.name + "_validation_scores.json") + file_name_suffix = '_oos' if oos_val else '' + scores_file = os.path.join(config.output_dir, config.name + f"_validation_scores{file_name_suffix}.json") with open(scores_file, 'r') as f: scores = json.load(f)