-
Notifications
You must be signed in to change notification settings - Fork 2
/
streamlit_app.py
135 lines (104 loc) · 3.8 KB
/
streamlit_app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
import streamlit as st
from src.preprocess import explode_reviews, preprocess_data
from src.embeddings import embed_reviews, reduce_dimensions_append_array
from src.extract_topic import summarize_sequential
from src.cluster import cluster_and_append, find_closest_to_centroid
from src.visualize import visualize_embeddings, plot_over_time
from src.ui import radio_filter, range_filter
REVIEW_COL = "review_text"

# Raw review column backing each "Review Type" option offered in the sidebar.
REVIEW_TYPE_COLUMNS = {
    "Likes": "likes",
    "Dislikes": "dislikes",
    "Use-case": "usecase",
}


def select_reviews_of_type(df, review_type):
    """Return ``id`` plus the review column for *review_type*, renamed to REVIEW_COL.

    Args:
        df: Frame containing ``id`` and the raw review columns
            (``likes``, ``dislikes``, ``usecase``).
        review_type: A key of REVIEW_TYPE_COLUMNS ("Likes", "Dislikes",
            "Use-case").

    Returns:
        A new two-column DataFrame with columns ``["id", REVIEW_COL]``;
        ``df`` itself is not modified.

    Raises:
        ValueError: If ``review_type`` is not a known review type.
    """
    source_col = REVIEW_TYPE_COLUMNS.get(review_type)
    if source_col is None:
        raise ValueError(
            f"Unexpected review type: {review_type!r} "
            f"(expected one of {list(REVIEW_TYPE_COLUMNS)})"
        )
    # Rename to a common column name so downstream steps are type-agnostic.
    return df[["id", source_col]].rename(columns={source_col: REVIEW_COL})
# Set page to wide mode. Streamlit documents set_page_config as needing to be
# the first Streamlit command on a page, so it runs before data loading.
# NOTE(review): preprocess_data may render Streamlit elements (e.g. a cache
# spinner if it is cached) — unverified, but calling this first is safe either way.
st.set_page_config(layout="wide")

df_cleaned = preprocess_data("./data/g2_reviews.json")

# Review metadata kept alongside each review sentence; "product.slug" is the
# company identifier used by the company picker below.
BASE_COLUMNS = [
    "id",
    "url",
    "product.slug",
    "name",
    "type",
    "helpful",
    "score",
    "segment",
    "role",
    "title",
    "source.type",
    "country",
    "region",
    "date_submitted",
    "date_published",
    "industry",
]
base_df = df_cleaned[BASE_COLUMNS]

sb = st.sidebar
## Select a company
# Each option reads "<slug> (<review count>)"; options follow value_counts order.
company_counts = base_df["product.slug"].value_counts()
company_labels = [f"{slug} ({n_reviews})" for slug, n_reviews in company_counts.items()]
# Display label -> company slug, so the chosen label can be mapped back.
companies_with_counts = dict(zip(company_labels, company_counts.index))

selected_company_label = sb.selectbox("Company", companies_with_counts.keys())
selected_company = companies_with_counts[selected_company_label]
## Select a review type, then split and vectorize those reviews.
# This runs over every company's reviews; the company filter is applied later,
# when the embedded sentences are merged with the company's metadata.
review_type = sb.radio("Review Type", ["Likes", "Dislikes", "Use-case"])
df_of_type = select_reviews_of_type(df_cleaned, review_type)

# Split the selected review text into sentences (presumably one row each).
with st.spinner("Parsing review sentences..."):
    xpl_df = explode_reviews(df_of_type, REVIEW_COL)

# Adds the f"{REVIEW_COL}_embeddings" column that clustering reads later.
with st.spinner("Vectorizing Reviews..."):
    embedded_df = embed_reviews(xpl_df, REVIEW_COL)
# Filter to selected company: join the company's review metadata onto its
# embedded sentences (DataFrame.merge defaults to an inner join).
company_df = base_df[base_df["product.slug"] == selected_company].merge(
    embedded_df, on="id"
)

# The inner join yields no rows when none of this company's reviews produced
# sentences of the chosen type. Stop with a message rather than handing an
# empty frame to clustering (which presumably cannot cluster zero samples).
if company_df.empty:
    st.warning(f"No {review_type} reviews found for {selected_company}.")
    st.stop()

# Column names shared by the clustering and labelling steps below.
EMBEDDINGS_COL = f"{REVIEW_COL}_embeddings"
CLUSTER_ID_COL = f"{REVIEW_COL}_embeddings_cluster_id"

with st.spinner("Clustering Reviews..."):
    clustered_df = cluster_and_append(company_df, EMBEDDINGS_COL)

# Pick up to this many representative reviews per cluster (presumably those
# nearest each centroid) and summarize them into a cluster label.
NUM_REVIEWS_TO_USE_IN_CLUSTER_LABEL = 30
closest_docs = find_closest_to_centroid(
    clustered_df,
    NUM_REVIEWS_TO_USE_IN_CLUSTER_LABEL,
    EMBEDDINGS_COL,
    CLUSTER_ID_COL,
    REVIEW_COL,
)
# Keyed by cluster id; each value carries a "cluster_label" entry.
top_cluster_docs = summarize_sequential(closest_docs, review_type)
top_cluster_map = {
    cluster_id: data["cluster_label"] for cluster_id, data in top_cluster_docs.items()
}
# Cluster ids missing from the map become NaN labels (Series.map semantics).
clustered_df["cluster_label"] = clustered_df[CLUSTER_ID_COL].map(top_cluster_map)
## Reduce the embedding space to 2D for visualization
reduce_dim_df = reduce_dimensions_append_array(
    clustered_df,
    f"{REVIEW_COL}_embeddings",
    num_dimensions=2,
    dim_col_name="dims_2d",
)

#### FILTERS
# Each filter receives the sidebar handle and narrows the previous result;
# they are applied (and their widgets created) in this order.
filtered_df = reduce_dim_df
for filter_label, filter_column in (("Source", "source.type"), ("Segment", "segment")):
    filtered_df = radio_filter(filter_label, sb, filtered_df, filter_column)
filtered_df = range_filter("Review Date", sb, filtered_df, "date_published")
### Colour Selector
# Radio option -> column used to colour the embedding plot. Insertion order
# defines the radio's option order, with "Cluster" first (the default).
COLOUR_BY_COLUMNS = {
    "Cluster": "cluster_label",
    "Segment": "segment",
    "Source": "source.type",
}
colour_by_selected = st.radio(
    "Colour by", options=list(COLOUR_BY_COLUMNS), index=0, horizontal=True
)
colour_by_col = COLOUR_BY_COLUMNS[colour_by_selected]

# Plot the 2-D projection of the filtered sentences, coloured as selected.
fig_clusters = visualize_embeddings(
    filtered_df,
    coords_col="dims_2d",
    review_text_column=REVIEW_COL,
    colour_by_column=colour_by_col,
)
st.plotly_chart(fig_clusters, use_container_width=True)

# Plot the filtered reviews against their publish date.
fig_publish_dates = plot_over_time(filtered_df, "date_published")
st.plotly_chart(fig_publish_dates, use_container_width=True)