-
Notifications
You must be signed in to change notification settings - Fork 12
/
Copy pathsession.py
89 lines (69 loc) · 2.77 KB
/
session.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
#
# Code for managing session state, which is needed for multi-input forms
# See https://github.com/streamlit/streamlit/issues/1557
#
# This code is taken from
# https://gist.github.com/okld/0aba4869ba6fdc8d49132e6974e2e662
#
from streamlit.hashing import _CodeHasher
from streamlit.report_thread import get_report_ctx
from streamlit.server.server import Server
class _SessionState:
def __init__(self, session, hash_funcs):
"""Initialize SessionState instance."""
self.__dict__["_state"] = {
"data": {},
"hash": None,
"hasher": _CodeHasher(hash_funcs),
"is_rerun": False,
"session": session,
}
def __call__(self, **kwargs):
"""Initialize state data once."""
for item, value in kwargs.items():
if item not in self._state["data"]:
self._state["data"][item] = value
def __getitem__(self, item):
"""Return a saved state value, None if item is undefined."""
return self._state["data"].get(item, None)
def __getattr__(self, item):
"""Return a saved state value, None if item is undefined."""
return self._state["data"].get(item, None)
def __setitem__(self, item, value):
"""Set state value."""
self._state["data"][item] = value
def __setattr__(self, item, value):
"""Set state value."""
self._state["data"][item] = value
def clear(self):
"""Clear session state and request a rerun."""
self._state["data"].clear()
self._state["session"].request_rerun()
def sync(self):
"""
Rerun the app with all state values up to date from the beginning to
fix rollbacks.
"""
data_to_bytes = self._state["hasher"].to_bytes(self._state["data"], None)
# Ensure to rerun only once to avoid infinite loops
# caused by a constantly changing state value at each run.
#
# Example: state.value += 1
if self._state["is_rerun"]:
self._state["is_rerun"] = False
elif self._state["hash"] is not None:
if self._state["hash"] != data_to_bytes:
self._state["is_rerun"] = True
self._state["session"].request_rerun()
self._state["hash"] = data_to_bytes
def _get_session():
session_id = get_report_ctx().session_id
session_info = Server.get_current()._get_session_info(session_id)
if session_info is None:
raise RuntimeError("Couldn't get your Streamlit Session object.")
return session_info.session
def _get_state(hash_funcs=None):
session = _get_session()
if not hasattr(session, "_custom_session_state"):
session._custom_session_state = _SessionState(session, hash_funcs)
return session._custom_session_state