forked from scikit-learn/scikit-learn
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path_config.py
179 lines (146 loc) · 6.1 KB
/
_config.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
"""Global configuration state and functions for management
"""
import os
from contextlib import contextmanager as contextmanager
import threading
def _env_flag(name, default=False):
    """Read a boolean flag from the environment variable ``name``.

    Parameters
    ----------
    name : str
        Name of the environment variable to read.
    default : bool, default=False
        Value returned when the variable is not set at all.

    Returns
    -------
    flag : bool
        ``False`` when the variable is set to an empty string or to one of
        ``"0"``, ``"false"``, ``"no"``, ``"off"`` (case-insensitive, surrounding
        whitespace ignored); ``True`` for any other value.
    """
    value = os.environ.get(name)
    if value is None:
        return default
    # A plain ``bool(value)`` treats every non-empty string -- including "0"
    # and "False" -- as True, which made an explicit opt-out impossible.
    return value.strip().lower() not in {"", "0", "false", "no", "off"}


# Default configuration. Each thread lazily copies this dict the first time it
# asks for its configuration (see ``_get_threadlocal_config``); per-thread
# changes made through ``set_config`` do not write back here.
_global_config = {
    "assume_finite": _env_flag("SKLEARN_ASSUME_FINITE", False),
    "working_memory": int(os.environ.get("SKLEARN_WORKING_MEMORY", 1024)),
    "print_changed_only": True,
    "display": "text",
}

# Per-thread storage; holds the thread's mutable config under ``global_config``.
_threadlocal = threading.local()
def _get_threadlocal_config():
    """Return the calling thread's **mutable** configuration dict.

    The dict is created lazily: the first call made from a given thread
    stores a shallow copy of ``_global_config`` on ``_threadlocal``, and
    every later call from that thread returns that same object.
    """
    try:
        return _threadlocal.global_config
    except AttributeError:
        # First access from this thread: seed it from the default config.
        _threadlocal.global_config = _global_config.copy()
        return _threadlocal.global_config
def get_config():
    """Retrieve current values for configuration set by :func:`set_config`.

    Returns
    -------
    config : dict
        Keys are parameter names that can be passed to :func:`set_config`.

    See Also
    --------
    config_context : Context manager for global scikit-learn configuration.
    set_config : Set global scikit-learn configuration.
    """
    # Hand back a shallow snapshot rather than the live thread-local dict, so
    # mutating the returned mapping cannot alter the active configuration.
    current = _get_threadlocal_config()
    return dict(current)
def set_config(
    assume_finite=None, working_memory=None, print_changed_only=None, display=None
):
    """Set global scikit-learn configuration

    .. versionadded:: 0.19

    Only the arguments that are not ``None`` are applied; the remaining
    settings keep their current values in the calling thread's configuration.

    Parameters
    ----------
    assume_finite : bool, default=None
        If True, validation for finiteness will be skipped,
        saving time, but leading to potential crashes. If
        False, validation for finiteness will be performed,
        avoiding error. Global default: False.

        .. versionadded:: 0.19

    working_memory : int, default=None
        If set, scikit-learn will attempt to limit the size of temporary arrays
        to this number of MiB (per job when parallelised), often saving both
        computation time and memory on expensive operations that can be
        performed in chunks. Global default: 1024.

        .. versionadded:: 0.20

    print_changed_only : bool, default=None
        If True, only the parameters that were set to non-default
        values will be printed when printing an estimator. For example,
        ``print(SVC())`` while True will only print 'SVC()' while the default
        behaviour would be to print 'SVC(C=1.0, cache_size=200, ...)' with
        all the non-changed parameters. Global default: True.

        .. versionadded:: 0.21

    display : {'text', 'diagram'}, default=None
        If 'diagram', estimators will be displayed as a diagram in a Jupyter
        lab or notebook context. If 'text', estimators will be displayed as
        text. Global default: 'text'.

        .. versionadded:: 0.23

    See Also
    --------
    config_context : Context manager for global scikit-learn configuration.
    get_config : Retrieve current values of the global configuration.
    """
    # Insertion order matches the parameter order, so assignments happen in
    # the same sequence as the explicit per-parameter checks would.
    requested = {
        "assume_finite": assume_finite,
        "working_memory": working_memory,
        "print_changed_only": print_changed_only,
        "display": display,
    }
    local_config = _get_threadlocal_config()
    for setting, new_value in requested.items():
        # ``None`` means "leave this setting unchanged".
        if new_value is not None:
            local_config[setting] = new_value
@contextmanager
def config_context(
    *, assume_finite=None, working_memory=None, print_changed_only=None, display=None
):
    """Context manager for global scikit-learn configuration.

    Parameters
    ----------
    assume_finite : bool, default=None
        If True, validation for finiteness will be skipped,
        saving time, but leading to potential crashes. If
        False, validation for finiteness will be performed,
        avoiding error. If None, the existing value won't change.
        The default value is False.

    working_memory : int, default=None
        If set, scikit-learn will attempt to limit the size of temporary arrays
        to this number of MiB (per job when parallelised), often saving both
        computation time and memory on expensive operations that can be
        performed in chunks. If None, the existing value won't change.
        The default value is 1024.

    print_changed_only : bool, default=None
        If True, only the parameters that were set to non-default
        values will be printed when printing an estimator. For example,
        ``print(SVC())`` while True will only print 'SVC()', but would print
        'SVC(C=1.0, cache_size=200, ...)' with all the non-changed parameters
        when False. If None, the existing value won't change.
        The default value is True.

        .. versionchanged:: 0.23
           Default changed from False to True.

    display : {'text', 'diagram'}, default=None
        If 'diagram', estimators will be displayed as a diagram in a Jupyter
        lab or notebook context. If 'text', estimators will be displayed as
        text. If None, the existing value won't change.
        The default value is 'text'.

        .. versionadded:: 0.23

    Yields
    ------
    None.

    See Also
    --------
    set_config : Set global scikit-learn configuration.
    get_config : Retrieve current values of the global configuration.

    Notes
    -----
    All settings, not just those presently modified, will be returned to
    their previous values when the context manager is exited.

    Examples
    --------
    >>> import sklearn
    >>> from sklearn.utils.validation import assert_all_finite
    >>> with sklearn.config_context(assume_finite=True):
    ...     assert_all_finite([float('nan')])
    >>> with sklearn.config_context(assume_finite=True):
    ...     with sklearn.config_context(assume_finite=False):
    ...         assert_all_finite([float('nan')])
    Traceback (most recent call last):
    ...
    ValueError: Input contains NaN...
    """
    # Snapshot the full configuration before touching anything.
    snapshot = get_config()
    overrides = dict(
        assume_finite=assume_finite,
        working_memory=working_memory,
        print_changed_only=print_changed_only,
        display=display,
    )
    set_config(**overrides)
    try:
        yield
    finally:
        # Restore every setting from the snapshot, not just the overridden
        # ones, so changes made inside the ``with`` body are undone too.
        set_config(**snapshot)