-
Notifications
You must be signed in to change notification settings - Fork 3
/
st_state_patch.py
219 lines (147 loc) · 5.22 KB
/
st_state_patch.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
"""Another prototype of the State implementation.
Usage
-----
How to import this:

    import streamlit as st
    import st_state_patch
When you do that, you will get 3 new commands in the "st" module:
* st.State
* st.SessionState
* st.GlobalState
The important class here is st.State. The other two are just an alternate API
that provides some syntax sugar.
Using st.State
--------------
Just call st.State() and you'll get a session-specific object to add state into.
To initialize it, just use an "if" block, like this:

    s = st.State()
    if not s:
        # Initialize it here!
        s.foo = "bar"
If you want your state to be global rather than session-specific, pass the
"is_global" keyword argument:

    s = st.State(is_global=True)
    if not s:
        # Initialize it here!
        s.foo = "bar"
Alternate API
-------------
If you think this reads better, you can create session-specific and global State
objects with these commands instead:

    s0 = st.SessionState()
    # Same as st.State()

    s1 = st.GlobalState()
    # Same as st.State(is_global=True)
Multiple states per app
-----------------------
If you'd like to instantiate several State objects in the same app, this will
actually give you 2 different State instances:

    s0 = st.State()
    s1 = st.State()
    print(s0 == s1)  # Prints False

If that's not what you want, you can use the "key" argument to specify which
exact State object you want:

    s0 = st.State(key="user metadata")
    s1 = st.State(key="user metadata")
    print(s0 == s1)  # Prints True
"""
import inspect
import os
import threading
import collections
from streamlit.server.Server import Server
import streamlit as st
import streamlit.ReportThread as ReportThread
# Normally we'd use a Streamlit module, but I want a module that doesn't live in
# your current working directory (since local modules get removed in between
# runs), and Streamlit devs are likely to have Streamlit in their cwd.
# `sys` is a built-in module, so it is not affected by that cleanup. Anything
# _get_global_state stores on it is a plain module attribute, and is therefore
# shared process-wide (across reruns and sessions).
import sys
GLOBAL_CONTAINER = sys
class State(object):
    """Attribute bag whose instances are cached per session or process-wide.

    ``State(key=..., is_global=...)`` returns the instance already registered
    under ``key`` if there is one; otherwise it creates and registers a new
    one. When ``key`` is None, a key is derived from the calling site by
    ``_figure_out_key``. An instance is falsy until at least one attribute
    has been assigned on it, which makes ``if not s:`` a first-run
    initialization check.
    """

    def __new__(cls, key=None, is_global=False):
        lookup = _get_global_state if is_global else _get_session_state
        registry, key_counts = lookup()
        if key is None:
            key = _figure_out_key(key_counts)
        try:
            return registry[key]
        except KeyError:
            fresh = super(State, cls).__new__(cls)
            registry[key] = fresh
            return fresh

    def __init__(self, key=None, is_global=False):
        # Deliberately a no-op. Python runs __init__ after every State(...)
        # call, including calls that return a cached instance, so any setup
        # here would run again on existing state. The arguments are consumed
        # by __new__.
        pass

    def __bool__(self):
        # Truthy once any instance attribute has been assigned.
        return len(self.__dict__) > 0

    def __contains__(self, name):
        # Enables ``"foo" in state`` checks against assigned attributes.
        return name in vars(self)
def _get_global_state():
    """Return ``(states_dict, key_counts)`` for process-wide State objects.

    ``states_dict`` is stored on GLOBAL_CONTAINER, so it outlives script
    reruns and is shared by every session.

    ``key_counts`` (the per-call-site counters used by _figure_out_key) must
    restart at zero on every run. Previously they lived on GLOBAL_CONTAINER
    and were never reset, so each rerun generated a new auto-key and returned
    a fresh, empty State. They now live on the current thread, which (as for
    session state) is assumed to be recreated for every script run.

    Returns:
        tuple: (dict mapping key -> State, defaultdict(int) of key counts).
    """
    if not hasattr(GLOBAL_CONTAINER, '_global_state'):
        GLOBAL_CONTAINER._global_state = {}
    curr_thread = threading.current_thread()
    if not hasattr(curr_thread, '_global_key_counts'):
        # Kept separate from the session counters (`_key_counts`) so that
        # session and global auto-keys are numbered independently.
        curr_thread._global_key_counts = collections.defaultdict(int)
    return GLOBAL_CONTAINER._global_state, curr_thread._global_key_counts
def _get_session_state():
    """Return ``(states_dict, key_counts)`` for the current Streamlit session.

    ``states_dict`` is attached to the session object returned by
    _get_session_object and is created on first use. ``key_counts`` is
    attached to the current thread instead, because that thread is assumed
    to be recreated on every run, which resets the automatic key numbering.
    """
    session = _get_session_object()
    thread = threading.current_thread()

    try:
        states = session._session_state
    except AttributeError:
        session._session_state = states = {}

    try:
        counts = thread._key_counts
    except AttributeError:
        # Put this in the thread because it gets cleared on every run.
        thread._key_counts = counts = collections.defaultdict(int)

    return states, counts
def _get_session_object():
    """Return the Streamlit session object that is running the current script.

    This is a hack. It reaches into private Streamlit internals (the Server's
    session registry and per-session attributes), with branches for the
    Streamlit versions noted inline, and matches each session against the
    current thread's report context.

    Raises:
        RuntimeError: if there is no report context for this thread, or no
            session matches it.
    """
    # The two literals previously had no separator and rendered as
    # "...Session objectAre you...".
    error_message = (
        "Oh noes. Couldn't get your Streamlit Session object. "
        'Are you doing something fancy with threads?')

    ctx = ReportThread.get_report_ctx()
    if ctx is None:
        # NOTE(review): presumably get_report_ctx() returns None when called
        # from a thread Streamlit didn't start. Raise the same clear error
        # instead of an AttributeError on ctx below.
        raise RuntimeError(error_message)

    current_server = Server.get_current()
    if hasattr(current_server, '_session_infos'):
        # Streamlit < 0.56
        session_infos = current_server._session_infos.values()
    else:
        session_infos = current_server._session_info_by_id.values()

    for session_info in session_infos:
        session = session_info.session
        if hasattr(session, '_main_dg'):
            # Streamlit < 0.54.0
            is_match = session._main_dg == ctx.main_dg
        else:
            # Streamlit >= 0.54.0
            is_match = session.enqueue == ctx.enqueue
        if is_match:
            return session

    raise RuntimeError(error_message)
def _figure_out_key(key_counts):
    """Derive an automatic State key from the first caller outside this file.

    The key combines the caller's filename, its function name and its
    position on the stack, followed by a per-call-site counter taken from
    (and incremented in) ``key_counts``. Repeated calls from the same site
    therefore get distinct keys: ``"<site> :: 0"``, ``"<site> :: 1"``, ...

    Args:
        key_counts: mapping from call-site string to the number of keys
            already handed out for it. Missing entries must default to 0,
            e.g. ``collections.defaultdict(int)``.

    Returns:
        The key string, or None if every frame on the stack belongs to this
        file.
    """
    # context=0: only filenames and function names are needed, so skip
    # reading source lines for every frame (the default context=1 does that).
    stack = inspect.stack(0)
    try:
        # Keep only (filename, function name). Holding frame objects in a
        # local of this frame creates a reference cycle (see inspect docs).
        sites = [(frame_info[1], frame_info[3]) for frame_info in stack]
    finally:
        del stack

    for stack_pos, (filename, func_name) in enumerate(sites):
        if filename != __file__:
            break
    else:
        return None

    site = "%s :: %s :: %s" % (filename, func_name, stack_pos)
    count = key_counts[site]
    key_counts[site] += 1
    return "%s :: %s" % (site, count)
class SessionState(object):
    """Syntax sugar: ``SessionState(key)`` is the same as ``State(key=key)``.

    ``__new__`` returns a State rather than a SessionState, so callers always
    receive a plain State instance.
    """

    def __new__(cls, key=None):
        return State(is_global=False, key=key)
class GlobalState(object):
    """Syntax sugar: ``GlobalState(key)`` is ``State(key=key, is_global=True)``.

    ``__new__`` returns a State rather than a GlobalState, so callers always
    receive a plain State instance.
    """

    def __new__(cls, key=None):
        return State(is_global=True, key=key)
# Install the constructors on the streamlit module, so that after
# ``import st_state_patch`` they are reachable as st.State, st.GlobalState
# and st.SessionState.
st.State, st.GlobalState, st.SessionState = State, GlobalState, SessionState