-
Notifications
You must be signed in to change notification settings - Fork 0
/
helper.py
183 lines (164 loc) · 6.89 KB
/
helper.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
from river import stream
from river import cluster
class River:
    """Adapter exposing a river online-clustering model through a fit-style API.

    The wrapped model is updated one sample at a time via ``learn_one`` and
    queried via ``predict_one``. After :meth:`partial_fit`, the cluster
    assignments for the batch are available in ``labels_``.

    NOTE(review): presumably passed to BERTopic as its clustering model for
    online topic modeling (it returns ``self`` and sets ``labels_``) —
    confirm against the caller.
    """

    def __init__(self, model):
        """Wrap ``model``, a river estimator providing ``learn_one``/``predict_one``."""
        self.model = model

    def partial_fit(self, umap_embeddings):
        """Update the model with a batch of embeddings, then label that batch.

        Args:
            umap_embeddings: array-like of rows accepted by
                ``river.stream.iter_array``. It is iterated twice, so it must
                be re-iterable (e.g. a NumPy array or a list), not a one-shot
                iterator.

        Returns:
            ``self``, with ``labels_`` set to a list holding one predicted
            label per input row, in input order.
        """
        # Learn from the whole batch first, so every prediction below is made
        # against the model state after it has seen the full batch.
        for umap_embedding, _ in stream.iter_array(umap_embeddings):
            # Do not reassign the result: since river 0.19, ``learn_one``
            # updates the estimator in place and returns None, so
            # ``self.model = self.model.learn_one(x)`` would replace the model
            # with None. Older river versions also mutate in place, so this is
            # correct for both.
            self.model.learn_one(umap_embedding)

        self.labels_ = [
            self.model.predict_one(umap_embedding)
            for umap_embedding, _ in stream.iter_array(umap_embeddings)
        ]
        return self
import pandas as pd
from typing import List
import plotly.graph_objects as go
from sklearn.preprocessing import normalize
def visualize_topics_over_time(topic_model,
                               topics_over_time: pd.DataFrame,
                               top_n_topics: int = None,
                               topics: List[int] = None,
                               normalize_frequency: bool = False,
                               custom_labels: bool = False,
                               title: str = "<b>Topics over Time</b>",
                               width: int = 860,
                               height: int = 600) -> go.Figure:
    """Plot topic frequencies over time, one line per topic.

    Based on BERTopic's function https://github.com/MaartenGr/BERTopic/blob/809414b88ca3f12a46728069d098d82345986489/bertopic/plotting/_topics_over_time.py

    Args:
        topic_model: fitted topic model. It must provide ``get_topic_freq()``
            (a DataFrame with a ``Topic`` column) and the attributes
            ``topic_labels_``, ``custom_labels_`` and ``_outliers``.
        topics_over_time: DataFrame with ``Topic``, ``Timestamp``,
            ``Frequency`` and ``Words`` columns. It is not modified.
        top_n_topics: plot the topics in the first ``top_n_topics`` rows of
            ``get_topic_freq()`` (topic -1 excluded), ordered by topic id.
            Ignored when ``topics`` is given.
        topics: explicit topic ids to plot. Takes precedence over
            ``top_n_topics``.
        normalize_frequency: scale each topic's frequency series to unit L2
            norm (the ``sklearn.preprocessing.normalize`` default) instead of
            plotting raw values.
        custom_labels: use ``topic_model.custom_labels_`` for trace names when
            it is set.
        title: figure title (may contain Plotly HTML-like markup).
        width: figure width in pixels.
        height: figure height in pixels.

    Returns:
        A ``plotly.graph_objects.Figure`` with one line trace per selected
        topic present in ``topics_over_time``.
    """
    # Select topics based on the top_n_topics and topics args. Topic -1 is
    # excluded from the automatic selections only.
    freq_df = topic_model.get_topic_freq()
    freq_df = freq_df.loc[freq_df.Topic != -1, :]
    if topics is not None:
        selected_topics = list(topics)
    elif top_n_topics is not None:
        selected_topics = sorted(freq_df.Topic.to_list()[:top_n_topics])
    else:
        selected_topics = sorted(freq_df.Topic.to_list())

    # Map topic ids to display names.
    if topic_model.custom_labels_ is not None and custom_labels:
        # custom_labels_ is indexed by position. The _outliers offset
        # presumably shifts topic -1 (when present) to index 0 — confirm
        # against the model.
        topic_names = {key: topic_model.custom_labels_[key + topic_model._outliers]
                       for key in topic_model.topic_labels_}
    else:
        topic_names = {key: value[:30] + "..." if len(value) > 30 else value
                       for key, value in topic_model.topic_labels_.items()}

    # Build a new frame instead of writing a "Name" column into the caller's
    # DataFrame.
    data = topics_over_time.assign(Name=topics_over_time.Topic.map(topic_names))
    data = data.loc[data.Topic.isin(selected_topics), :].sort_values(["Topic", "Timestamp"])

    # Add one line trace per topic that actually has rows.
    fig = go.Figure()
    for topic in data.Topic.unique():
        trace_data = data.loc[data.Topic == topic, :]
        topic_name = trace_data.Name.values[0]
        words = trace_data.Words.values
        if normalize_frequency:
            # normalize() scales rows, so view the series as a single row.
            y = normalize(trace_data.Frequency.values.reshape(1, -1))[0]
        else:
            y = trace_data.Frequency
        fig.add_trace(go.Scatter(x=pd.to_datetime(trace_data.Timestamp), y=y,
                                 mode='lines',
                                 hoverinfo="text",
                                 name=topic_name,
                                 hovertext=[f'<b>Topic {topic}</b><br>Words: {word}' for word in words]))

    # Styling of the visualization
    fig.update_layout(
        yaxis_title="Normalized Frequency" if normalize_frequency else "Frequency",
        title={'text': f'{title}',
               'font': dict(size=22)},
        width=width,
        height=height,
        hoverlabel=dict(
            bgcolor="white",
            font_size=16,
        ),
        # Horizontal legend placed below the plot area.
        legend=dict(
            title="<b>Global Topic Representation",
            orientation="h",
            y=-.2,
            x=0,
        ),
    )
    return fig
def visualize_topics_per_class(topic_model,
                               topics_per_class: pd.DataFrame,
                               top_n_topics: int = 10,
                               topics: List[int] = None,
                               normalize_frequency: bool = False,
                               custom_labels: bool = False,
                               title: str = "<b>Topics per Class</b>",
                               width: int = 900,
                               height: int = 900) -> go.Figure:
    """Plot topic frequencies per class as horizontal bars, one trace per topic.

    Based on BERTopic's function https://github.com/MaartenGr/BERTopic/blob/809414b88ca3f12a46728069d098d82345986489/bertopic/plotting/_topics_per_class.py

    Only the first plotted topic is visible initially. The others start as
    "legendonly" and can be toggled on from the legend.

    Args:
        topic_model: fitted topic model. It must provide ``get_topic_freq()``
            (a DataFrame with a ``Topic`` column) and the attributes
            ``topic_labels_``, ``custom_labels_`` and ``_outliers``.
        topics_per_class: DataFrame with ``Topic``, ``Class``, ``Frequency``
            and ``Words`` columns. It is not modified.
        top_n_topics: plot the topics in the first ``top_n_topics`` rows of
            ``get_topic_freq()`` (topic -1 excluded), keeping that row order.
            Ignored when ``topics`` is given.
        topics: explicit topic ids to plot, in trace order. Takes precedence
            over ``top_n_topics``. Ids with no rows in ``topics_per_class``
            are skipped.
        normalize_frequency: scale each topic's per-class frequencies to unit
            L2 norm (the ``sklearn.preprocessing.normalize`` default) instead
            of plotting raw values.
        custom_labels: use ``topic_model.custom_labels_`` for trace names when
            it is set.
        title: figure title (may contain Plotly HTML-like markup).
        width: figure width in pixels.
        height: figure height in pixels.

    Returns:
        A ``plotly.graph_objects.Figure`` with one horizontal bar trace per
        selected topic that has rows in ``topics_per_class``.
    """
    # Select topics based on the top_n_topics and topics args. Topic -1 is
    # excluded from the automatic selections only.
    freq_df = topic_model.get_topic_freq()
    freq_df = freq_df.loc[freq_df.Topic != -1, :]
    if topics is not None:
        selected_topics = list(topics)
    elif top_n_topics is not None:
        # Deliberately not sorted by id. This order decides the trace order
        # and which topic is initially visible (the first one).
        selected_topics = freq_df.Topic.to_list()[:top_n_topics]
    else:
        selected_topics = sorted(freq_df.Topic.to_list())

    # Map topic ids to display names.
    if topic_model.custom_labels_ is not None and custom_labels:
        # custom_labels_ is indexed by position. The _outliers offset
        # presumably shifts topic -1 (when present) to index 0 — confirm
        # against the model.
        topic_names = {key: topic_model.custom_labels_[key + topic_model._outliers]
                       for key in topic_model.topic_labels_}
    else:
        topic_names = {key: value[:40] + "..." if len(value) > 40 else value
                       for key, value in topic_model.topic_labels_.items()}

    # Build a new frame instead of writing a "Name" column into the caller's
    # DataFrame.
    data = topics_per_class.assign(Name=topics_per_class.Topic.map(topic_names))
    data = data.loc[data.Topic.isin(selected_topics), :]

    # Add one bar trace per selected topic, in selection order.
    fig = go.Figure()
    first_trace = True
    for topic in selected_topics:
        trace_data = data.loc[data.Topic == topic, :]
        if trace_data.empty:
            # A requested topic may have no rows in topics_per_class. Skip it
            # instead of failing with IndexError on Name.values[0] below.
            continue
        visible = True if first_trace else "legendonly"
        first_trace = False

        topic_name = trace_data.Name.values[0]
        words = trace_data.Words.values
        if normalize_frequency:
            # normalize() scales rows, so view the series as a single row.
            x = normalize(trace_data.Frequency.values.reshape(1, -1))[0]
        else:
            x = trace_data.Frequency
        fig.add_trace(go.Bar(y=trace_data.Class,
                             x=x,
                             visible=visible,
                             hoverinfo="text",
                             name=topic_name,
                             orientation="h",
                             hovertext=[f'<b>Topic {topic}</b><br>Words: {word}' for word in words]))

    # Styling of the visualization
    fig.update_xaxes(showgrid=True)
    fig.update_yaxes(showgrid=True)
    fig.update_layout(
        xaxis_title="Normalized Frequency" if normalize_frequency else "Frequency",
        yaxis_title="Class",
        title={
            'text': f"{title}",
            'font': dict(
                size=22)
        },
        width=width,
        height=height,
        hoverlabel=dict(
            bgcolor="white",
            font_size=16,
        ),
        legend=dict(
            title="<b>Global Topic Representation",
        ),
    )
    return fig