Skip to content

Commit

Permalink
Merge pull request #111 from arnor-sigurdsson/update-attribution-dloa…
Browse files Browse the repository at this point in the history
…der-setup

Update attribution data loader setup, other minor fixes
  • Loading branch information
arnor-sigurdsson authored Jan 23, 2025
2 parents df65a98 + 655ba6f commit 75d9615
Show file tree
Hide file tree
Showing 7 changed files with 173 additions and 146 deletions.
11 changes: 10 additions & 1 deletion eir/interpretation/interpretation.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,16 @@ def check_if_should_run_analysis(


def get_background_loader(experiment: "Experiment") -> torch.utils.data.DataLoader:
background_loader = copy.deepcopy(experiment.train_loader)
original_loader = experiment.train_loader
shuffle = isinstance(original_loader.sampler, torch.utils.data.RandomSampler)

background_loader = torch.utils.data.DataLoader(
dataset=original_loader.dataset,
batch_size=original_loader.batch_size,
shuffle=shuffle,
num_workers=original_loader.num_workers,
pin_memory=original_loader.pin_memory,
)

return background_loader

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def predict_survival_wrapper_with_labels(
)

else:
risk_scores = model_outputs.numpy()
risk_scores = model_outputs.cpu().numpy()

baseline_hazard = output_object.baseline_hazard
unique_times = output_object.baseline_unique_times
Expand Down
21 changes: 21 additions & 0 deletions eir/setup/config_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,3 +264,24 @@ def validate_tensor_broker_configs(
f"The following tensor messages are used but not cached: "
f"{unused_needing_cache}"
)


def validate_global_input_config_sync(
global_config: schemas.GlobalConfig,
input_configs: Sequence[schemas.InputConfig],
) -> None:
will_compute_attrs = global_config.aa.compute_attributions
if not will_compute_attrs:
return

for input_config in input_configs:
input_type_info = input_config.input_type_info
match input_type_info:
case schemas.OmicsInputDataConfig():
if input_type_info.snp_file is None:
raise ValueError(
"To compute attributions, the input config must include the "
"path for the snp_file parameter (a .bim file). Kindly "
"fill in the snp_file parameter in the input config"
"with the path to the .bim file."
)
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def save_survival_evaluation_results_wrapper(
)

else:
risk_scores = model_outputs.numpy()
risk_scores = model_outputs.cpu().numpy()

unique_times, baseline_hazard = estimate_baseline_hazard(
times=times,
Expand Down
174 changes: 87 additions & 87 deletions poetry.lock

Large diffs are not rendered by default.

65 changes: 31 additions & 34 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
adabelief-pytorch==0.2.1 ; python_full_version >= "3.12.0" and python_version < "3.13"
aioboto3==13.4.0 ; python_full_version >= "3.12.0" and python_version < "3.13"
aiobotocore[boto3]==2.18.0 ; python_full_version >= "3.12.0" and python_version < "3.13"
aiobotocore==2.18.0 ; python_full_version >= "3.12.0" and python_version < "3.13"
aiofiles==24.1.0 ; python_full_version >= "3.12.0" and python_version < "3.13"
aiohappyeyeballs==2.4.4 ; python_full_version >= "3.12.0" and python_version < "3.13"
aiohttp==3.11.11 ; python_full_version >= "3.12.0" and python_version < "3.13"
Expand All @@ -23,11 +23,11 @@ autograd==1.7.0 ; python_full_version >= "3.12.0" and python_version < "3.13"
babel==2.16.0 ; python_full_version >= "3.12.0" and python_version < "3.13"
beautifulsoup4==4.12.3 ; python_full_version >= "3.12.0" and python_version < "3.13"
black==24.10.0 ; python_full_version >= "3.12.0" and python_version < "3.13"
bleach[css]==6.2.0 ; python_full_version >= "3.12.0" and python_version < "3.13"
bleach==6.2.0 ; python_full_version >= "3.12.0" and python_version < "3.13"
blessed==1.20.0 ; python_full_version >= "3.12.0" and python_version < "3.13"
boto3==1.36.1 ; python_full_version >= "3.12.0" and python_version < "3.13"
botocore==1.36.1 ; python_full_version >= "3.12.0" and python_version < "3.13"
cachetools==5.5.0 ; python_full_version >= "3.12.0" and python_version < "3.13"
cachetools==5.5.1 ; python_full_version >= "3.12.0" and python_version < "3.13"
captum==0.7.0 ; python_full_version >= "3.12.0" and python_version < "3.13"
certifi==2024.12.14 ; python_full_version >= "3.12.0" and python_version < "3.13"
cffi==1.17.1 ; python_full_version >= "3.12.0" and python_version < "3.13"
Expand All @@ -40,22 +40,21 @@ comm==0.2.2 ; python_full_version >= "3.12.0" and python_version < "3.13"
configargparse==1.7 ; python_full_version >= "3.12.0" and python_version < "3.13"
contourpy==1.3.1 ; python_full_version >= "3.12.0" and python_version < "3.13"
coverage==7.6.10 ; python_full_version >= "3.12.0" and python_version < "3.13"
coverage[toml]==7.6.10 ; python_full_version >= "3.12.0" and python_version < "3.13"
cycler==0.12.1 ; python_full_version >= "3.12.0" and python_version < "3.13"
debugpy==1.8.12 ; python_full_version >= "3.12.0" and python_version < "3.13"
decorator==5.1.1 ; python_full_version >= "3.12.0" and python_version < "3.13"
deeplake==4.1.4 ; python_full_version >= "3.12.0" and python_version < "3.13"
deeplake==4.1.3 ; python_full_version >= "3.12.0" and python_version < "3.13"
defusedxml==0.7.1 ; python_full_version >= "3.12.0" and python_version < "3.13"
deptry==0.21.2 ; python_full_version >= "3.12.0" and python_version < "3.13"
distlib==0.3.9 ; python_full_version >= "3.12.0" and python_version < "3.13"
docutils==0.21.2 ; python_full_version >= "3.12.0" and python_version < "3.13"
einops==0.8.0 ; python_full_version >= "3.12.0" and python_version < "3.13"
executing==2.1.0 ; python_full_version >= "3.12.0" and python_version < "3.13"
executing==2.2.0 ; python_full_version >= "3.12.0" and python_version < "3.13"
fastapi==0.115.6 ; python_full_version >= "3.12.0" and python_version < "3.13"
fastjsonschema==2.21.1 ; python_full_version >= "3.12.0" and python_version < "3.13"
filelock==3.16.1 ; python_full_version >= "3.12.0" and python_version < "3.13"
filelock==3.17.0 ; python_full_version >= "3.12.0" and python_version < "3.13"
flake8==7.1.1 ; python_full_version >= "3.12.0" and python_version < "3.13"
fonttools==4.55.3 ; python_full_version >= "3.12.0" and python_version < "3.13"
fonttools==4.55.4 ; python_full_version >= "3.12.0" and python_version < "3.13"
formulaic==1.1.1 ; python_full_version >= "3.12.0" and python_version < "3.13"
fqdn==1.5.1 ; python_full_version >= "3.12.0" and python_version < "3.13"
frozenlist==1.5.0 ; python_full_version >= "3.12.0" and python_version < "3.13"
Expand All @@ -67,8 +66,8 @@ h11==0.14.0 ; python_full_version >= "3.12.0" and python_version < "3.13"
httpcore==1.0.7 ; python_full_version >= "3.12.0" and python_version < "3.13"
httpx==0.28.1 ; python_full_version >= "3.12.0" and python_version < "3.13"
huggingface-hub==0.27.1 ; python_full_version >= "3.12.0" and python_version < "3.13"
hypothesis==6.124.1 ; python_full_version >= "3.12.0" and python_version < "3.13"
identify==2.6.5 ; python_full_version >= "3.12.0" and python_version < "3.13"
hypothesis==6.124.2 ; python_full_version >= "3.12.0" and python_version < "3.13"
identify==2.6.6 ; python_full_version >= "3.12.0" and python_version < "3.13"
idna==3.10 ; python_full_version >= "3.12.0" and python_version < "3.13"
imagesize==1.4.1 ; python_full_version >= "3.12.0" and python_version < "3.13"
iniconfig==2.0.0 ; python_full_version >= "3.12.0" and python_version < "3.13"
Expand All @@ -87,7 +86,6 @@ json5==0.10.0 ; python_full_version >= "3.12.0" and python_version < "3.13"
jsonpointer==3.0.0 ; python_full_version >= "3.12.0" and python_version < "3.13"
jsonschema-specifications==2024.10.1 ; python_full_version >= "3.12.0" and python_version < "3.13"
jsonschema==4.23.0 ; python_full_version >= "3.12.0" and python_version < "3.13"
jsonschema[format-nongpl]==4.23.0 ; python_full_version >= "3.12.0" and python_version < "3.13"
jupyter-client==8.6.3 ; python_full_version >= "3.12.0" and python_version < "3.13"
jupyter-console==6.6.3 ; python_full_version >= "3.12.0" and python_version < "3.13"
jupyter-core==5.7.2 ; python_full_version >= "3.12.0" and python_version < "3.13"
Expand Down Expand Up @@ -121,20 +119,20 @@ networkx==3.4.2 ; python_full_version >= "3.12.0" and python_version < "3.13"
nodeenv==1.9.1 ; python_full_version >= "3.12.0" and python_version < "3.13"
notebook-shim==0.2.4 ; python_full_version >= "3.12.0" and python_version < "3.13"
notebook==7.3.2 ; python_full_version >= "3.12.0" and python_version < "3.13"
numpy==2.1.3 ; python_version >= "3.12" and python_version < "3.13"
nvidia-cublas-cu12==12.4.5.8 ; platform_system == "Linux" and platform_machine == "x86_64" and python_full_version >= "3.12.0" and python_version < "3.13"
nvidia-cuda-cupti-cu12==12.4.127 ; platform_system == "Linux" and platform_machine == "x86_64" and python_full_version >= "3.12.0" and python_version < "3.13"
nvidia-cuda-nvrtc-cu12==12.4.127 ; platform_system == "Linux" and platform_machine == "x86_64" and python_full_version >= "3.12.0" and python_version < "3.13"
nvidia-cuda-runtime-cu12==12.4.127 ; platform_system == "Linux" and platform_machine == "x86_64" and python_full_version >= "3.12.0" and python_version < "3.13"
nvidia-cudnn-cu12==9.1.0.70 ; platform_system == "Linux" and platform_machine == "x86_64" and python_full_version >= "3.12.0" and python_version < "3.13"
nvidia-cufft-cu12==11.2.1.3 ; platform_system == "Linux" and platform_machine == "x86_64" and python_full_version >= "3.12.0" and python_version < "3.13"
nvidia-curand-cu12==10.3.5.147 ; platform_system == "Linux" and platform_machine == "x86_64" and python_full_version >= "3.12.0" and python_version < "3.13"
nvidia-cusolver-cu12==11.6.1.9 ; platform_system == "Linux" and platform_machine == "x86_64" and python_full_version >= "3.12.0" and python_version < "3.13"
nvidia-cusparse-cu12==12.3.1.170 ; platform_system == "Linux" and platform_machine == "x86_64" and python_full_version >= "3.12.0" and python_version < "3.13"
numpy==2.1.3 ; python_full_version >= "3.12.0" and python_version < "3.13"
nvidia-cublas-cu12==12.4.5.8 ; python_full_version >= "3.12.0" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64"
nvidia-cuda-cupti-cu12==12.4.127 ; python_full_version >= "3.12.0" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64"
nvidia-cuda-nvrtc-cu12==12.4.127 ; python_full_version >= "3.12.0" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64"
nvidia-cuda-runtime-cu12==12.4.127 ; python_full_version >= "3.12.0" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64"
nvidia-cudnn-cu12==9.1.0.70 ; python_full_version >= "3.12.0" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64"
nvidia-cufft-cu12==11.2.1.3 ; python_full_version >= "3.12.0" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64"
nvidia-curand-cu12==10.3.5.147 ; python_full_version >= "3.12.0" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64"
nvidia-cusolver-cu12==11.6.1.9 ; python_full_version >= "3.12.0" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64"
nvidia-cusparse-cu12==12.3.1.170 ; python_full_version >= "3.12.0" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64"
nvidia-ml-py==12.560.30 ; python_full_version >= "3.12.0" and python_version < "3.13"
nvidia-nccl-cu12==2.21.5 ; platform_system == "Linux" and platform_machine == "x86_64" and python_full_version >= "3.12.0" and python_version < "3.13"
nvidia-nvjitlink-cu12==12.4.127 ; platform_system == "Linux" and platform_machine == "x86_64" and python_full_version >= "3.12.0" and python_version < "3.13"
nvidia-nvtx-cu12==12.4.127 ; platform_system == "Linux" and platform_machine == "x86_64" and python_full_version >= "3.12.0" and python_version < "3.13"
nvidia-nccl-cu12==2.21.5 ; python_full_version >= "3.12.0" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64"
nvidia-nvjitlink-cu12==12.4.127 ; python_full_version >= "3.12.0" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64"
nvidia-nvtx-cu12==12.4.127 ; python_full_version >= "3.12.0" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64"
overrides==7.7.0 ; python_full_version >= "3.12.0" and python_version < "3.13"
packaging==24.2 ; python_full_version >= "3.12.0" and python_version < "3.13"
pandas==2.2.3 ; python_full_version >= "3.12.0" and python_version < "3.13"
Expand All @@ -147,12 +145,12 @@ pillow==11.1.0 ; python_full_version >= "3.12.0" and python_version < "3.13"
platformdirs==4.3.6 ; python_full_version >= "3.12.0" and python_version < "3.13"
pluggy==1.5.0 ; python_full_version >= "3.12.0" and python_version < "3.13"
polars==1.20.0 ; python_full_version >= "3.12.0" and python_version < "3.13"
pre-commit==4.0.1 ; python_full_version >= "3.12.0" and python_version < "3.13"
pre-commit==4.1.0 ; python_full_version >= "3.12.0" and python_version < "3.13"
prometheus-client==0.21.1 ; python_full_version >= "3.12.0" and python_version < "3.13"
prompt-toolkit==3.0.49 ; python_full_version >= "3.12.0" and python_version < "3.13"
prompt-toolkit==3.0.50 ; python_full_version >= "3.12.0" and python_version < "3.13"
propcache==0.2.1 ; python_full_version >= "3.12.0" and python_version < "3.13"
psutil==6.1.1 ; python_full_version >= "3.12.0" and python_version < "3.13"
ptyprocess==0.7.0 ; python_full_version >= "3.12.0" and python_version < "3.13" and (sys_platform != "win32" and sys_platform != "emscripten" or os_name != "nt")
ptyprocess==0.7.0 ; python_full_version >= "3.12.0" and python_version < "3.13" and (sys_platform != "win32" and sys_platform != "emscripten") or python_full_version >= "3.12.0" and python_version < "3.13" and os_name != "nt"
pure-eval==0.2.3 ; python_full_version >= "3.12.0" and python_version < "3.13"
py==1.11.0 ; python_full_version >= "3.12.0" and python_version < "3.13"
pyarrow==18.1.0 ; python_full_version >= "3.12.0" and python_version < "3.13"
Expand All @@ -164,7 +162,7 @@ pyflakes==3.2.0 ; python_full_version >= "3.12.0" and python_version < "3.13"
pygments==2.19.1 ; python_full_version >= "3.12.0" and python_version < "3.13"
pynvim==0.5.2 ; python_full_version >= "3.12.0" and python_version < "3.13"
pyparsing==3.2.1 ; python_full_version >= "3.12.0" and python_version < "3.13"
pyproject-api==1.8.0 ; python_full_version >= "3.12.0" and python_version < "3.13"
pyproject-api==1.9.0 ; python_full_version >= "3.12.0" and python_version < "3.13"
pysocks==1.7.1 ; python_full_version >= "3.12.0" and python_version < "3.13"
pytest-cov==6.0.0 ; python_full_version >= "3.12.0" and python_version < "3.13"
pytest-split==0.10.0 ; python_full_version >= "3.12.0" and python_version < "3.13"
Expand All @@ -174,14 +172,13 @@ python-json-logger==3.2.1 ; python_full_version >= "3.12.0" and python_version <
pytorch-ignite==0.5.1 ; python_full_version >= "3.12.0" and python_version < "3.13"
pytorch-ranger==0.1.1 ; python_full_version >= "3.12.0" and python_version < "3.13"
pytz==2024.2 ; python_full_version >= "3.12.0" and python_version < "3.13"
pywin32==308 ; sys_platform == "win32" and platform_python_implementation != "PyPy" and python_full_version >= "3.12.0" and python_version < "3.13"
pywin32==308 ; python_full_version >= "3.12.0" and python_version < "3.13" and sys_platform == "win32" and platform_python_implementation != "PyPy"
pywinpty==2.0.14 ; python_full_version >= "3.12.0" and python_version < "3.13" and os_name == "nt"
pyyaml==6.0.2 ; python_full_version >= "3.12.0" and python_version < "3.13"
pyzmq==26.2.0 ; python_full_version >= "3.12.0" and python_version < "3.13"
referencing==0.36.1 ; python_full_version >= "3.12.0" and python_version < "3.13"
regex==2024.11.6 ; python_full_version >= "3.12.0" and python_version < "3.13"
requests==2.32.3 ; python_full_version >= "3.12.0" and python_version < "3.13"
requests[socks]==2.32.3 ; python_full_version >= "3.12.0" and python_version < "3.13"
requirements-parser==0.11.0 ; python_full_version >= "3.12.0" and python_version < "3.13"
rfc3339-validator==0.1.4 ; python_full_version >= "3.12.0" and python_version < "3.13"
rfc3986-validator==0.1.1 ; python_full_version >= "3.12.0" and python_version < "3.13"
Expand All @@ -192,7 +189,7 @@ scikit-learn==1.6.1 ; python_full_version >= "3.12.0" and python_version < "3.13
scipy==1.15.1 ; python_full_version >= "3.12.0" and python_version < "3.13"
seaborn==0.13.2 ; python_full_version >= "3.12.0" and python_version < "3.13"
send2trash==1.8.3 ; python_full_version >= "3.12.0" and python_version < "3.13"
setuptools==75.8.0 ; python_version >= "3.12" and python_version < "3.13"
setuptools==75.8.0 ; python_full_version >= "3.12.0" and python_version < "3.13"
six==1.17.0 ; python_full_version >= "3.12.0" and python_version < "3.13"
snakeviz==2.2.2 ; python_full_version >= "3.12.0" and python_version < "3.13"
sniffio==1.3.1 ; python_full_version >= "3.12.0" and python_version < "3.13"
Expand Down Expand Up @@ -224,16 +221,16 @@ torch-optimizer==0.3.0 ; python_full_version >= "3.12.0" and python_version < "3
torch==2.5.1 ; python_full_version >= "3.12.0" and python_version < "3.13"
torchvision==0.20.1 ; python_full_version >= "3.12.0" and python_version < "3.13"
tornado==6.4.2 ; python_full_version >= "3.12.0" and python_version < "3.13"
tox==4.23.2 ; python_full_version >= "3.12.0" and python_version < "3.13"
tox==4.24.1 ; python_full_version >= "3.12.0" and python_version < "3.13"
tqdm==4.67.1 ; python_full_version >= "3.12.0" and python_version < "3.13"
traitlets==5.14.3 ; python_full_version >= "3.12.0" and python_version < "3.13"
transformers==4.46.3 ; python_full_version >= "3.12.0" and python_version < "3.13"
triton==3.1.0 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version < "3.13" and python_full_version >= "3.12.0"
triton==3.1.0 ; python_full_version >= "3.12.0" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64"
types-python-dateutil==2.9.0.20241206 ; python_full_version >= "3.12.0" and python_version < "3.13"
types-pyyaml==6.0.12.20241230 ; python_full_version >= "3.12.0" and python_version < "3.13"
types-setuptools==75.8.0.20250110 ; python_full_version >= "3.12.0" and python_version < "3.13"
typing-extensions==4.12.2 ; python_full_version >= "3.12.0" and python_version < "3.13"
tzdata==2024.2 ; python_full_version >= "3.12.0" and python_version < "3.13"
tzdata==2025.1 ; python_full_version >= "3.12.0" and python_version < "3.13"
uri-template==1.3.0 ; python_full_version >= "3.12.0" and python_version < "3.13"
urllib3==2.3.0 ; python_full_version >= "3.12.0" and python_version < "3.13"
uvicorn==0.32.1 ; python_full_version >= "3.12.0" and python_version < "3.13"
Expand Down
Loading

0 comments on commit 75d9615

Please sign in to comment.