Skip to content

Commit

Permalink
⚡ do not build dataloaders for multiprocessing
Browse files (browse the repository at this point in the history)
  • Loading branch information
Henry committed Jul 9, 2024
1 parent 8c4e53b commit 49a93d0
Showing 1 changed file with 24 additions and 10 deletions.
34 changes: 24 additions & 10 deletions src/move/tasks/identify_associations.py
Original file line number Diff line number Diff line change
Expand Up @@ -829,11 +829,17 @@ def identify_associations(config: MOVEConfig) -> None:
logger.info(f"Perturbation type: {task_config.target_value}")
output_subpath = Path(output_path) / "perturbation_visualization"
output_subpath.mkdir(exist_ok=True, parents=True)

dataloaders = prepare_for_continuous_perturbation(
config, output_subpath, baseline_dataloader
)
if not task_config.multiprocess:
dataloaders = prepare_for_continuous_perturbation(
config, output_subpath, baseline_dataloader
)
feature_mask = nan_mask
con_dataset_names = config.data.continuous_names
target_idx = con_dataset_names.index(
task_config.target_dataset
) # dataset index
logger.debug(f"Cont. shapes: {baseline_dataset.con_shapes} [take {target_idx}]")
num_perturbed = baseline_dataset.con_shapes[target_idx]

# Identify associations between categorical and continuous features:
else:
Expand All @@ -848,13 +854,21 @@ def identify_associations(config: MOVEConfig) -> None:
target_value = one_hot_encode_single(target_mapping, task_config.target_value)
feature_mask = np.all(target_dataset == target_value, axis=2) # 2D: N x P
feature_mask |= np.sum(target_dataset, axis=2) == 0
dataloaders = prepare_for_categorical_perturbation(
config, interim_path, baseline_dataloader, cat_list
if not task_config.multiprocess:
dataloaders = prepare_for_categorical_perturbation(
config, interim_path, baseline_dataloader, cat_list
)
num_perturbed = target_dataset.shape[-1]
logger.info(
f"Cat. shapes: {baseline_dataset.cat_shapes}"
f" [take {target_dataset_idx}]"
)
target_shape = baseline_dataset.cat_shapes[target_dataset_idx]
num_perturbed = target_shape[0]

# APPROACH EVALUATION ##########################
num_perturbed = len(dataloaders) - 1 # P
logger.debug(f"# perturbed features: {num_perturbed}")
# num_perturbed = len(dataloaders) - 1 # P
# logger.debug(f"# perturbed features: {num_perturbed}")

if task_type == "bayes":
task_config = cast(IdentifyAssociationsBayesConfig, task_config)
Expand Down Expand Up @@ -895,7 +909,7 @@ def identify_associations(config: MOVEConfig) -> None:
elif task_type == "ttest":
task_config = cast(IdentifyAssociationsTTestConfig, task_config)
if task_config.multiprocess:
logger.warning("Multiprocessing is not supported for ttest approach.")
raise NotImplementedError("Multiprocessing is not supported for T-test.")
sig_ids, *extra_cols = _ttest_approach(
task_config,
train_dataloader,
Expand All @@ -915,7 +929,7 @@ def identify_associations(config: MOVEConfig) -> None:
elif task_type == "ks":
task_config = cast(IdentifyAssociationsKSConfig, task_config)
if task_config.multiprocess:
logger.warning("Multiprocessing is not supported for KS approach.")
raise NotImplementedError("Multiprocessing is not supported for KS.")
sig_ids, *extra_cols = _ks_approach(
config,
task_config,
Expand Down

0 comments on commit 49a93d0

Please sign in to comment.