Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

838 create tabular tab to dashboard and redesign loaded data results #819 #844

Merged
merged 98 commits into from
Sep 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
98 commits
Select commit Hold shift + click to select a range
12a0ab7
move method, params, prediction, and top to show to container
laurasootes Jul 31, 2024
2e444e4
fix text
laurasootes Aug 1, 2024
3ce30b6
set message before tickboxes
laurasootes Aug 1, 2024
125d0a8
specify that the row label represents a class
laurasootes Aug 1, 2024
911965e
switch rise and lime for consistency with other pages
laurasootes Aug 1, 2024
2957be9
move method_params into method checkboxes function
laurasootes Aug 1, 2024
bf164bb
move method params into method checkboxes to get the params expander …
laurasootes Aug 1, 2024
acc1a32
delete method header since they are per method now
laurasootes Aug 1, 2024
ac85823
make method axis label equal size as class axis label
laurasootes Aug 1, 2024
4d8ef36
move prediction to top ourside of container
laurasootes Aug 1, 2024
6f89ba3
center titles
laurasootes Aug 1, 2024
1918b3a
add some space
laurasootes Aug 1, 2024
4c62ab3
update text label
laurasootes Aug 2, 2024
eb9af6d
make the row elements smaller
laurasootes Aug 2, 2024
7459ed3
add some space
laurasootes Aug 2, 2024
8432fe3
remove sidebar from home page
laurasootes Aug 2, 2024
3b7b5f4
use simple logo add
laurasootes Aug 2, 2024
310da69
move predicted class inside container
laurasootes Aug 2, 2024
ed61dbb
move predicted class inside container and all sidebar logo
laurasootes Aug 2, 2024
b3de953
Merge branch '818-dashboard-modifu-input-data-section-for-multiple-ex…
laurasootes Aug 2, 2024
afc3092
use page wide layout instead of small centered
laurasootes Aug 5, 2024
343f905
Merge branch '790-add-scientific-use-case-frb-to-dashboard' of https:…
laurasootes Aug 6, 2024
af5e7a2
Merge branch '790-add-scientific-use-case-frb-to-dashboard' of https:…
laurasootes Aug 7, 2024
4a3c451
fix typo
laurasootes Aug 7, 2024
80603d4
fix double checkbox
laurasootes Aug 7, 2024
dbdbc67
fix key label
laurasootes Aug 7, 2024
906b6e9
fix popping
laurasootes Aug 8, 2024
eb5d3f4
use f strings
laurasootes Aug 8, 2024
637dae5
remove print statement
laurasootes Aug 8, 2024
84a873b
Merge branch '790-add-scientific-use-case-frb-to-dashboard' of https:…
laurasootes Aug 8, 2024
4d5b04a
initialize tab for tabular
laurasootes Aug 8, 2024
083ebc8
add default exception for tabular
laurasootes Aug 8, 2024
db94c26
Merge branch '819-dashboard-redesign-loaded-data-results' of https://…
laurasootes Aug 8, 2024
a60f3e8
Set default for tabular page
laurasootes Aug 8, 2024
8ac3e51
add tabular page reference
laurasootes Aug 9, 2024
70fb93d
add tabular example setup
laurasootes Aug 9, 2024
f3629bb
add upload menu
laurasootes Aug 9, 2024
ecb6161
fix imports and draft prediction
laurasootes Aug 9, 2024
e494c62
draft model
laurasootes Aug 9, 2024
80c284f
add table visualisation with row selection option
laurasootes Aug 9, 2024
00d9386
add load data for csv to pandas
laurasootes Aug 9, 2024
04574c4
update for non-classifier model
laurasootes Aug 9, 2024
852570e
allow for uploads without labels if not classifier
laurasootes Aug 9, 2024
bd20045
allow for labels to be empty
laurasootes Aug 9, 2024
6714801
reset
laurasootes Aug 9, 2024
8886611
show prediction
laurasootes Aug 9, 2024
5e0727d
temp for weather data
laurasootes Aug 9, 2024
78d8826
only if input is given
laurasootes Aug 12, 2024
891ef84
move info message
laurasootes Aug 12, 2024
a8bcbb2
change variable names
laurasootes Aug 12, 2024
5b03fe9
model prediction for tabular
laurasootes Aug 12, 2024
5d64e69
use iloc
laurasootes Aug 12, 2024
5ba7624
import dispatcher
laurasootes Aug 12, 2024
4d61118
update for tabular
laurasootes Aug 12, 2024
05a8d37
fix location of serialization
laurasootes Aug 12, 2024
6f2ed43
add import
laurasootes Aug 12, 2024
615f1f5
set empty string if not a classifier
laurasootes Aug 12, 2024
91a04af
tmp
laurasootes Aug 13, 2024
a58d472
add required data type
laurasootes Aug 13, 2024
c352994
load training data
laurasootes Aug 13, 2024
924b38f
delete unused import
laurasootes Aug 13, 2024
880f0db
convert to flaot for explainer
laurasootes Aug 13, 2024
1007464
provide correct input
laurasootes Aug 13, 2024
7be7045
add break point
laurasootes Aug 13, 2024
aa16a21
add mode
laurasootes Aug 13, 2024
87fd6af
add training data argument
laurasootes Aug 13, 2024
0de7c9f
add load training data function
laurasootes Aug 13, 2024
66e9079
plot relevances results
laurasootes Aug 13, 2024
004023b
only show predicted class if classification model
laurasootes Aug 13, 2024
924472c
only add feature names kwarg for LIME
laurasootes Aug 14, 2024
3ac3949
add comment for dashboard
laurasootes Aug 14, 2024
be49476
make sure data is floats
laurasootes Aug 21, 2024
b108ca4
add feature names
laurasootes Aug 21, 2024
b14e928
convert to float32
laurasootes Aug 21, 2024
20f0135
fix keyword arguments for tabular
laurasootes Aug 22, 2024
ef5308c
no kernelshap
laurasootes Aug 22, 2024
556aa5f
kernelshap function for tabular
laurasootes Aug 22, 2024
125cc16
update text
laurasootes Aug 22, 2024
d24075e
undo weather example specifics
laurasootes Aug 22, 2024
b1133fb
add test for tabular
laurasootes Aug 22, 2024
40e4011
Merge branch 'main' of https://github.com/dianna-ai/dianna into 838-c…
laurasootes Aug 22, 2024
38e1d70
fix merge issue
laurasootes Aug 22, 2024
ba1e296
fix typo
laurasootes Aug 22, 2024
fb497b9
fix input conversion
laurasootes Aug 22, 2024
e1b2eb2
undo import
laurasootes Aug 22, 2024
c53bd96
delete unused imports
laurasootes Aug 22, 2024
ddef7bf
delete print statement
laurasootes Aug 22, 2024
d1c5d1c
check for both positive and negatice
laurasootes Aug 22, 2024
3707fd9
test for two classes
laurasootes Aug 22, 2024
37d5a5d
fix tests
laurasootes Aug 22, 2024
3d5c64f
make ruff happy
laurasootes Aug 22, 2024
123e2ab
increase timeout
laurasootes Aug 23, 2024
7f75520
fix linting
laurasootes Aug 23, 2024
b2c3590
fix linting
laurasootes Aug 23, 2024
5d8f196
fix sorting
laurasootes Aug 23, 2024
3fd787d
fix imports
laurasootes Aug 23, 2024
ce38ea9
increase timeout?
laurasootes Aug 23, 2024
f38fd09
increase timeouts
laurasootes Aug 23, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 11 additions & 8 deletions dianna/dashboard/Home.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
import importlib
import streamlit as st
from _shared import add_sidebar_logo
from _shared import data_directory
from streamlit_option_menu import option_menu

st.set_page_config(page_title="Dianna's dashboard",
page_icon='📊',
layout='centered',
layout='wide',
initial_sidebar_state='auto',
menu_items={
'Get help':
Expand All @@ -22,6 +21,7 @@
pages = {
"Home": "home",
"Images": "pages.Images",
"Tabular": "pages.Tabular",
"Text": "pages.Text",
"Time series": "pages.Time_series"
}
Expand All @@ -30,16 +30,14 @@
selected = option_menu(
menu_title=None,
options=list(pages.keys()),
icons=["house", "camera", "alphabet", "clock"],
icons=["house", "camera", "table", "alphabet", "clock"],
menu_icon="cast",
default_index=0,
orientation="horizontal"
)

# Display the content of the selected page
if selected == "Home":
add_sidebar_logo()

st.image(str(data_directory / 'logo.png'))

st.markdown("""
Expand All @@ -50,9 +48,10 @@

### Pages

- <a href="/Images" target="_parent">Images</a>
- <a href="/Text" target="_parent">Text</a>
- <a href="/Time_series" target="_parent">Time series</a>
- <a href="/Images" target="_parent">Image data</a>
- <a href="/Tabular" target="_parent">Tabular data</a>
- <a href="/Text" target="_parent">Text data</a>
- <a href="/Time_series" target="_parent">Time series data</a>


### More information
Expand All @@ -70,6 +69,10 @@
for k in st.session_state.keys():
if 'Image' in k:
st.session_state.pop(k, None)
if selected != 'Tabular':
for k in st.session_state.keys():
if 'Tabular' in k:
st.session_state.pop(k, None)
if selected != 'Text':
for k in st.session_state.keys():
if 'Text' in k:
Expand Down
13 changes: 13 additions & 0 deletions dianna/dashboard/_model_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,15 @@
from pathlib import Path
import numpy as np
import onnx
import pandas as pd


def load_data(file):
"""Open data from a file and returns it as pandas DataFrame."""
df = pd.read_csv(file, parse_dates=True)
# Add index column
df.insert(0, 'Index', df.index)
return df


def preprocess_function(image):
Expand Down Expand Up @@ -29,3 +38,7 @@ def load_labels(file):
if labels is None or labels == ['']:
raise ValueError(labels)
return labels


def load_training_data(file):
return np.float32(np.load(file, allow_pickle=False))
57 changes: 57 additions & 0 deletions dianna/dashboard/_models_tabular.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import tempfile
import numpy as np
import streamlit as st
from dianna import explain_tabular
from dianna.utils.onnx_runner import SimpleModelRunner


@st.cache_data
def predict(*, model, tabular_input):
model_runner = SimpleModelRunner(model)
predictions = model_runner(tabular_input.reshape(1,-1).astype(np.float32))
return predictions


@st.cache_data
def _run_rise_tabular(_model, table, training_data, **kwargs):
relevances = explain_tabular(
_model,
table,
method='RISE',
training_data=training_data,
**kwargs,
)
return relevances


@st.cache_data
def _run_lime_tabular(_model, table, training_data, _feature_names, **kwargs):
relevances = explain_tabular(
_model,
table,
method='LIME',
training_data=training_data,
feature_names=_feature_names,
**kwargs,
)
return relevances

@st.cache_data
def _run_kernelshap_tabular(model, table, training_data, **kwargs):
# Kernelshap interface is different. Write model to temporary file.
with tempfile.NamedTemporaryFile() as f:
f.write(model)
f.flush()
relevances = explain_tabular(f.name,
table,
method='KernelSHAP',
training_data=training_data,
**kwargs)
return relevances[0]


explain_tabular_dispatcher = {
'RISE': _run_rise_tabular,
'LIME': _run_lime_tabular,
'KernelSHAP': _run_kernelshap_tabular
}
96 changes: 48 additions & 48 deletions dianna/dashboard/_shared.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
import base64
import sys
from typing import Any
from typing import Dict
from typing import Sequence
import numpy as np
import streamlit as st
Expand Down Expand Up @@ -46,71 +44,67 @@ def build_markup_for_logo(


def add_sidebar_logo():
"""Based on: https://stackoverflow.com/a/73278825."""
png_file = data_directory / 'logo.png'
logo_markup = build_markup_for_logo(png_file)
st.markdown(
logo_markup,
unsafe_allow_html=True,
)
"""Upload DIANNA logo to sidebar element."""
st.sidebar.image(str(data_directory / 'logo.png'))


def _methods_checkboxes(*, choices: Sequence, key):
"""Get methods from a horizontal row of checkboxes."""
"""Get methods from a horizontal row of checkboxes and the corresponding parameters."""
n_choices = len(choices)
methods = []
method_params = {}

# Create a container for the message
message_container = st.empty()

for col, method in zip(st.columns(n_choices), choices):
with col:
if st.checkbox(method, key=key + method):
if st.checkbox(method, key=f'{key}_{method}'):
methods.append(method)
with st.expander(f'Click to modify {method} parameters'):
method_params[method] = _get_params(method, key=f'{key}_param')

if not methods:
st.info('Select a method to continue')
# Put the message in the container above
message_container.info('Select a method to continue')
st.stop()

return methods
return methods, method_params


def _get_params(method: str, key):
if method == 'RISE':
return {
'n_masks':
st.number_input('Number of masks', value=1000, key=key + method + 'nmasks'),
st.number_input('Number of masks', value=1000, key=f'{key}_{method}_nmasks'),
'feature_res':
st.number_input('Feature resolution', value=6, key=key + method + 'fr'),
st.number_input('Feature resolution', value=6, key=f'{key}_{method}_fr'),
'p_keep':
st.number_input('Probability to be kept unmasked', value=0.1, key=key + method + 'pkeep'),
st.number_input('Probability to be kept unmasked', value=0.1, key=f'{key}_{method}_pkeep'),
}

elif method == 'KernelSHAP':
return {
'nsamples': st.number_input('Number of samples', value=1000, key=key + method + 'nsamp'),
'background': st.number_input('Background', value=0, key=key + method + 'background'),
'n_segments': st.number_input('Number of segments', value=200, key=key + method + 'nseg'),
'sigma': st.number_input('σ', value=0, key=key + method + 'sigma'),
}
if 'Tabular' in key:
return {'training_data_kmeans': st.number_input('Training data kmeans', value=5,
key=f'{key}_{method}_training_data_kmeans'),
}
else:
return {
'nsamples': st.number_input('Number of samples', value=1000, key=f'{key}_{method}_nsamp'),
'background': st.number_input('Background', value=0, key=f'{key}_{method}_background'),
'n_segments': st.number_input('Number of segments', value=200, key=f'{key}_{method}_nseg'),
'sigma': st.number_input('σ', value=0, key=f'{key}_{method}_sigma'),
}

elif method == 'LIME':
return {
'random_state': st.number_input('Random state', value=2, key=key + method + 'rs'),
'random_state': st.number_input('Random state', value=2, key=f'{key}_{method}_rs'),
}

else:
raise ValueError(f'No such method: {method}')


def _get_method_params(methods: Sequence[str], key) -> Dict[str, Dict[str, Any]]:
method_params = {}

with st.expander('Click to modify method parameters'):
for method, col in zip(methods, st.columns(len(methods))):
with col:
st.header(method)
method_params[method] = _get_params(method, key=key)

return method_params


def _get_top_indices(predictions, n_top):
indices = np.array(np.argpartition(predictions, -n_top)[-n_top:])
indices = indices[np.argsort(predictions[indices])]
Expand All @@ -119,29 +113,35 @@ def _get_top_indices(predictions, n_top):


def _get_top_indices_and_labels(*, predictions, labels):
c1, c2 = st.columns(2)
cols = st.columns(4)

with c2:
n_top = st.number_input('Number of top results to show',
value=2,
min_value=1,
max_value=len(labels))
if labels is not None:
with cols[-1]:
n_top = st.number_input('Number of top classes to show',
value=1,
min_value=1,
max_value=len(labels))

top_indices = _get_top_indices(predictions, n_top)
top_labels = [labels[i] for i in top_indices]
top_indices = _get_top_indices(predictions, n_top)
top_labels = [labels[i] for i in top_indices]

with c1:
st.metric('Predicted class', top_labels[0])
with cols[0]:
st.metric('Predicted class:', top_labels[0])
else:
# If not a classifier, only return the predicted value
top_indices = top_labels = " "
with cols[0]:
st.metric('Predicted value:', f"{predictions[0]:.2f}")

return top_indices, top_labels

def reset_method():
# Clear selection
for k in st.session_state.keys():
if '_cb_' in k:
st.session_state[k] = False
if 'params' in k:
if '_param' in k:
st.session_state.pop(k)
elif '_cb' in k:
st.session_state[k] = False

def reset_example():
# Clear selection
Expand Down
25 changes: 16 additions & 9 deletions dianna/dashboard/pages/Images.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from _model_utils import load_model
from _models_image import explain_image_dispatcher
from _models_image import predict
from _shared import _get_method_params
from _shared import _get_top_indices_and_labels
from _shared import _methods_checkboxes
from _shared import add_sidebar_logo
Expand Down Expand Up @@ -88,15 +87,23 @@
labels = load_labels(image_label_file)

choices = ('RISE', 'KernelSHAP', 'LIME')
methods = _methods_checkboxes(choices=choices, key='Image_cb_')

method_params = _get_method_params(methods, key='Image_params_')
st.text("")
st.text("")

with st.spinner('Predicting class'):
predictions = predict(model=model, image=image)
with st.container(border=True):
prediction_placeholder = st.empty()
methods, method_params = _methods_checkboxes(choices=choices, key='Image_cb')

top_indices, top_labels = _get_top_indices_and_labels(predictions=predictions,
labels=labels)
with st.spinner('Predicting class'):
predictions = predict(model=model, image=image)

with prediction_placeholder:
top_indices, top_labels = _get_top_indices_and_labels(
predictions=predictions,labels=labels)

st.text("")
st.text("")

# check which axis is color channel
original_data = image[:, :, 0] if image.shape[2] <= 3 else image[1, :, :]
Expand All @@ -107,11 +114,11 @@

_, *columns = st.columns(column_spec)
for col, method in zip(columns, methods):
col.header(method)
col.markdown(f"<h4 style='text-align: center; '>{method}</h4>", unsafe_allow_html=True)

for index, label in zip(top_indices, top_labels):
index_col, *columns = st.columns(column_spec)
index_col.markdown(f'##### {label}')
index_col.markdown(f'##### Class: {label}')

for col, method in zip(columns, methods):
kwargs = method_params[method].copy()
Expand Down
Loading
Loading