From 49a93d03fb4063f20e7e72e97aa046bd2072ac70 Mon Sep 17 00:00:00 2001 From: Henry Date: Tue, 9 Jul 2024 19:55:52 +0200 Subject: [PATCH] :zap: do not build dataloaders for multiprocessing --- src/move/tasks/identify_associations.py | 34 +++++++++++++++++-------- 1 file changed, 24 insertions(+), 10 deletions(-) diff --git a/src/move/tasks/identify_associations.py b/src/move/tasks/identify_associations.py index 0aeb129..7c5e225 100644 --- a/src/move/tasks/identify_associations.py +++ b/src/move/tasks/identify_associations.py @@ -829,11 +829,17 @@ def identify_associations(config: MOVEConfig) -> None: logger.info(f"Perturbation type: {task_config.target_value}") output_subpath = Path(output_path) / "perturbation_visualization" output_subpath.mkdir(exist_ok=True, parents=True) - - dataloaders = prepare_for_continuous_perturbation( - config, output_subpath, baseline_dataloader - ) + if not task_config.multiprocess: + dataloaders = prepare_for_continuous_perturbation( + config, output_subpath, baseline_dataloader + ) feature_mask = nan_mask + con_dataset_names = config.data.continuous_names + target_idx = con_dataset_names.index( + task_config.target_dataset + ) # dataset index + logger.debug(f"Cont. shapes: {baseline_dataset.con_shapes} [take {target_idx}]") + num_perturbed = baseline_dataset.con_shapes[target_idx] # Identify associations between categorical and continuous features: else: @@ -848,13 +854,21 @@ def identify_associations(config: MOVEConfig) -> None: target_value = one_hot_encode_single(target_mapping, task_config.target_value) feature_mask = np.all(target_dataset == target_value, axis=2) # 2D: N x P feature_mask |= np.sum(target_dataset, axis=2) == 0 - dataloaders = prepare_for_categorical_perturbation( - config, interim_path, baseline_dataloader, cat_list + if not task_config.multiprocess: + dataloaders = prepare_for_categorical_perturbation( + config, interim_path, baseline_dataloader, cat_list + ) + num_perturbed = target_dataset.shape[-1] + logger.info( + f"Cat. shapes: {baseline_dataset.cat_shapes}" + f" [take {target_dataset_idx}]" ) + target_shape = baseline_dataset.cat_shapes[target_dataset_idx] + num_perturbed = target_shape[0] # APPROACH EVALUATION ########################## - num_perturbed = len(dataloaders) - 1 # P - logger.debug(f"# perturbed features: {num_perturbed}") + # num_perturbed = len(dataloaders) - 1 # P + # logger.debug(f"# perturbed features: {num_perturbed}") if task_type == "bayes": task_config = cast(IdentifyAssociationsBayesConfig, task_config) @@ -895,7 +909,7 @@ def identify_associations(config: MOVEConfig) -> None: elif task_type == "ttest": task_config = cast(IdentifyAssociationsTTestConfig, task_config) if task_config.multiprocess: - logger.warning("Multiprocessing is not supported for ttest approach.") + raise NotImplementedError("Multiprocessing is not supported for T-test.") sig_ids, *extra_cols = _ttest_approach( task_config, train_dataloader, @@ -915,7 +929,7 @@ def identify_associations(config: MOVEConfig) -> None: elif task_type == "ks": task_config = cast(IdentifyAssociationsKSConfig, task_config) if task_config.multiprocess: - logger.warning("Multiprocessing is not supported for KS approach.") + raise NotImplementedError("Multiprocessing is not supported for KS.") sig_ids, *extra_cols = _ks_approach( config, task_config,