Skip to content

Commit

Permalink
🎨 move feat_mask creation out
Browse files Browse the repository at this point in the history
  • Loading branch information
Henry committed Jul 9, 2024
1 parent f895237 commit 5eb7954
Showing 1 changed file with 12 additions and 13 deletions.
25 changes: 12 additions & 13 deletions src/move/tasks/identify_associations.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,15 +120,7 @@ def prepare_for_categorical_perturbation(
)
dataloaders.append(baseline_dataloader)

target_dataset_idx = config.data.categorical_names.index(task_config.target_dataset)
target_dataset = cat_list[target_dataset_idx]
feature_mask = np.all(target_dataset == target_value, axis=2) # 2D: N x P
feature_mask |= np.sum(target_dataset, axis=2) == 0

return (
dataloaders,
feature_mask,
)
return dataloaders


def prepare_for_continuous_perturbation(
Expand Down Expand Up @@ -846,10 +838,17 @@ def identify_associations(config: MOVEConfig) -> None:
# Identify associations between categorical and continuous features:
else:
logger.info("Beginning task: identify associations categorical")
(
dataloaders,
feature_mask,
) = prepare_for_categorical_perturbation(
task_config = cast(IdentifyAssociationsConfig, config.task)
target_dataset_idx = config.data.categorical_names.index(
task_config.target_dataset
)
target_dataset = cat_list[target_dataset_idx]
mappings = io.load_mappings(interim_path / "mappings.json")
target_mapping = mappings[task_config.target_dataset]
target_value = one_hot_encode_single(target_mapping, task_config.target_value)
feature_mask = np.all(target_dataset == target_value, axis=2) # 2D: N x P
feature_mask |= np.sum(target_dataset, axis=2) == 0
dataloaders = prepare_for_categorical_perturbation(
config, interim_path, baseline_dataloader, cat_list
)

Expand Down

0 comments on commit 5eb7954

Please sign in to comment.