This repository has been archived by the owner on Mar 13, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 64
/
Copy path_env.py
127 lines (105 loc) · 4.45 KB
/
_env.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# Author: kingjr, 2020
import contextlib
import multiprocessing
import logging
import typing as tp
from pathlib import Path
import yaml
from .utils import identify_host
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
class Env:
    """Global environment variable providing study and cache paths if available

    This is called as bm.env

    The class is a singleton: ``Env()`` always returns the same object.
    Note that ``__init__`` still runs on every ``Env()`` call, so calling it
    again reloads the default study paths and resets ``cache`` and
    ``feature_models`` to None.

    Attributes
    ----------
    cache: cache folder for precomputation, or None when unset.
    feature_models: folder of the models used to create features
        (Eg: word embeddings), or None when unset.
    studies: mapping from study name to its path.
    """

    _instance: tp.Optional["Env"] = None

    def __new__(cls) -> "Env":
        """Singleton pattern: build the instance once, then always reuse it."""
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self) -> None:
        self._studies: tp.Dict[str, Path] = self.study_default_paths()
        self.cache: tp.Optional[Path] = None  # cache for precomputation
        # models used to create features (Eg: word embeddings)
        self.feature_models: tp.Optional[Path] = None
        # Hijacking this part of the code as it is one of the first
        # to be executed, so a great place to set the start method.
        try:
            multiprocessing.set_start_method('fork')
        except (RuntimeError, ValueError):
            # RuntimeError: a start method has already been set.
            # ValueError: 'fork' is not available on this platform.
            logger.warning("Could not set start method, cache might not work properly.")

    @staticmethod
    def _get_host_study_paths(
        all_study_paths: tp.Mapping[str, tp.Dict[str, str]]
    ) -> tp.Dict[str, str]:
        """Get study paths for the current host.

        Parameters
        ----------
        all_study_paths: mapping from host name to that host's study paths.
            It must contain a 'default' entry, which is used when the
            identified host has no entry of its own.

        Returns
        -------
        The study paths of the identified host, or the 'default' ones.

        Raises
        ------
        KeyError: if neither the host nor 'default' is in the mapping.
        """
        hostname = identify_host()
        logger.debug(f'Identified host {hostname}.')
        study_paths = all_study_paths.get(hostname)
        if study_paths is None:  # Use default paths
            logger.warning(
                f'Hostname {hostname} not defined in '
                '/conf/study_paths/study_paths.yaml. Using default paths.')
            study_paths = all_study_paths['default']
        return study_paths

    @classmethod
    def study_default_paths(cls) -> tp.Dict[str, Path]:
        """Fills the study paths with their default value in study_paths.yaml

        The file is read from conf/study_paths/study_paths.yaml next to this
        module. Only the paths which exist on disk are returned.
        """
        fp = Path(__file__).parent / "conf" / "study_paths" / "study_paths.yaml"
        assert fp.exists(), f"Missing study paths file: {fp}"
        with fp.open() as f:
            content = yaml.safe_load(f)
        logger.debug(content)
        study_paths = cls._get_host_study_paths(content)
        # Paths which do not exist on this machine are dropped.
        return {x: Path(y) for x, y in study_paths.items() if Path(y).exists()}

    @contextlib.contextmanager
    def temporary_from_args(self, args: tp.Any, wipe_studies: bool = False) -> tp.Iterator[None]:
        """Update cache, features and study paths

        The changes only last for the duration of the "with" context.

        Parameters
        ----------
        args: mapping of option names to values. The keys 'cache' and
            'feature_models' are converted to Path, 'study_paths' is resolved
            for the current host and merged into the studies. Entries whose
            value is None, and any other key, are ignored.
        wipe_studies: if True, the studies paths currently in the env will be wiped, if False
            only the specified keys will override the current paths.
        """
        # self.studies returns a copy, so updating it below leaves the
        # current environment untouched until self.temporary applies it.
        kwargs: tp.Dict[str, tp.Any] = dict(studies={} if wipe_studies else self.studies)
        for name, val in args.items():
            if val is None:
                continue
            if name in ('cache', 'feature_models'):
                kwargs[name] = Path(val)
            elif name == "study_paths":
                study_paths = self._get_host_study_paths(val)
                kwargs["studies"].update(
                    {x: Path(y) for x, y in study_paths.items()})
        with self.temporary(**kwargs):
            yield

    @property
    def studies(self) -> tp.Dict[str, Path]:
        """Study name to path mapping (a copy: mutating it has no effect)."""
        return dict(self._studies)

    @studies.setter
    def studies(self, paths: tp.Dict[str, tp.Union[str, Path]]) -> None:
        self._studies = {name: Path(path) for name, path in paths.items()}

    @contextlib.contextmanager
    def temporary(self, **kwargs: tp.Any) -> tp.Iterator[None]:
        """Temporarily replaces a path by the provided one
        for the duration of the "with" context.

        Parameters
        ----------
        **kwargs: attribute names of this instance with their temporary
            value. String values are converted to Path.

        Raises
        ------
        AttributeError: if a name is not an existing attribute. Attributes
            which were already replaced are restored before raising.
        """
        currents: tp.Dict[str, tp.Any] = {}
        try:
            # The swap happens inside the try block so that a failure
            # halfway through still restores what was already replaced.
            for key, val in kwargs.items():
                if isinstance(val, str):
                    val = Path(val)
                currents[key] = getattr(self, key)
                setattr(self, key, val)
            yield
        finally:
            for key, val in currents.items():
                setattr(self, key, val)

    def __repr__(self) -> str:
        # Public instance attributes, plus the studies which are stored
        # under a private name.
        vals = {k: x for k, x in self.__dict__.items() if not k.startswith("_")}
        vals["studies"] = self._studies
        string = ",".join(f"{x}={y}" for x, y in sorted(vals.items()))
        return f"{self.__class__.__name__}({string})"
# Shared Env instance, created when this module is imported.
env = Env()