diff --git a/src/move/tasks/identify_associations.py b/src/move/tasks/identify_associations.py index 91f2660..0aeb129 100644 --- a/src/move/tasks/identify_associations.py +++ b/src/move/tasks/identify_associations.py @@ -120,15 +120,7 @@ def prepare_for_categorical_perturbation( ) dataloaders.append(baseline_dataloader) - target_dataset_idx = config.data.categorical_names.index(task_config.target_dataset) - target_dataset = cat_list[target_dataset_idx] - feature_mask = np.all(target_dataset == target_value, axis=2) # 2D: N x P - feature_mask |= np.sum(target_dataset, axis=2) == 0 - - return ( - dataloaders, - feature_mask, - ) + return dataloaders def prepare_for_continuous_perturbation( @@ -846,10 +838,17 @@ def identify_associations(config: MOVEConfig) -> None: # Identify associations between categorical and continuous features: else: logger.info("Beginning task: identify associations categorical") - ( - dataloaders, - feature_mask, - ) = prepare_for_categorical_perturbation( + task_config = cast(IdentifyAssociationsConfig, config.task) + target_dataset_idx = config.data.categorical_names.index( + task_config.target_dataset + ) + target_dataset = cat_list[target_dataset_idx] + mappings = io.load_mappings(interim_path / "mappings.json") + target_mapping = mappings[task_config.target_dataset] + target_value = one_hot_encode_single(target_mapping, task_config.target_value) + feature_mask = np.all(target_dataset == target_value, axis=2) # 2D: N x P + feature_mask |= np.sum(target_dataset, axis=2) == 0 + dataloaders = prepare_for_categorical_perturbation( config, interim_path, baseline_dataloader, cat_list )