-
Notifications
You must be signed in to change notification settings - Fork 3
/
clusters_vis.py
35 lines (23 loc) · 996 Bytes
/
clusters_vis.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import pandas as pd
import plotly.express as px
from pyparsing import col
import streamlit as st
from glob import glob
import os.path
# Parquet files offered for visualisation. glob() returns matches in
# arbitrary, filesystem-dependent order, so sort them to keep the file
# selector listing stable between runs and machines.
files = sorted(glob("data/s2orc/clusts_vis/*.parquet"))
@st.cache
def load_data(file_name: str) -> pd.DataFrame:
    """Read the plotting columns of one parquet file into a DataFrame.

    Only the ``x``, ``y``, ``noun_chunk`` and ``cluster_name`` columns are
    requested from the file. The function is wrapped in ``st.cache``.
    """
    # NOTE(review): newer Streamlit releases deprecate st.cache in favour of
    # st.cache_data -- confirm the pinned Streamlit version before switching.
    wanted_columns = ['x', 'y', 'noun_chunk', 'cluster_name']
    return pd.read_parquet(file_name, columns=wanted_columns)
st.title('Clusters visualization')

# With no matching parquet files the selectbox has no options and returns
# None, which would make load_data() and os.path.basename() fail below.
# Tell the user and end this script run cleanly instead.
if not files:
    st.warning('No parquet files found in data/s2orc/clusts_vis/')
    st.stop()

file_name = st.selectbox("Select a file", files)
df = load_data(file_name)
st.text(os.path.basename(file_name))

# Offer every distinct cluster name in the file, in order of first appearance.
clusters_names = df['cluster_name'].unique()
clust_names = st.multiselect('Select cluster', clusters_names)

# Keep only the rows of the selected clusters; with nothing selected the
# filtered frame is empty.
df2 = df[df['cluster_name'].isin(clust_names)]
st.plotly_chart(
    px.scatter(
        df2,
        x='x',
        y='y',
        color='cluster_name',
        hover_data=['noun_chunk'],
        hover_name='cluster_name',
    )
)