From 12a0ab71d566799cfd13b0f471c494c1633f4f00 Mon Sep 17 00:00:00 2001 From: Laura Ootes Date: Wed, 31 Jul 2024 22:40:07 +0200 Subject: [PATCH 01/92] move method, params, prediction, and top to show to container --- dianna/dashboard/pages/Images.py | 14 ++++++++------ dianna/dashboard/pages/Text.py | 16 +++++++++------- dianna/dashboard/pages/Time_series.py | 14 ++++++++------ 3 files changed, 25 insertions(+), 19 deletions(-) diff --git a/dianna/dashboard/pages/Images.py b/dianna/dashboard/pages/Images.py index 4a9954d6..be6bb066 100644 --- a/dianna/dashboard/pages/Images.py +++ b/dianna/dashboard/pages/Images.py @@ -87,15 +87,17 @@ labels = load_labels(image_label_file) choices = ('RISE', 'KernelSHAP', 'LIME') -methods = _methods_checkboxes(choices=choices, key='Image_cb_') -method_params = _get_method_params(methods, key='Image_params_') +with st.container(border=True): + methods = _methods_checkboxes(choices=choices, key='Image_cb_') -with st.spinner('Predicting class'): - predictions = predict(model=model, image=image) + method_params = _get_method_params(methods, key='Image_params_') -top_indices, top_labels = _get_top_indices_and_labels(predictions=predictions, - labels=labels) + with st.spinner('Predicting class'): + predictions = predict(model=model, image=image) + + top_indices, top_labels = _get_top_indices_and_labels(predictions=predictions, + labels=labels) # check which axis is color channel original_data = image[:, :, 0] if image.shape[2] <= 3 else image[1, :, :] diff --git a/dianna/dashboard/pages/Text.py b/dianna/dashboard/pages/Text.py index c8758d10..8511d077 100644 --- a/dianna/dashboard/pages/Text.py +++ b/dianna/dashboard/pages/Text.py @@ -78,17 +78,19 @@ labels = load_labels(text_label_file) choices = ('RISE', 'LIME') -methods = _methods_checkboxes(choices=choices, key='Text_cb_') -method_params = _get_method_params(methods, key='Text_params_') +with st.container(border=True): + methods = _methods_checkboxes(choices=choices, key='Text_cb_') -model_runner = MovieReviewsModelRunner(serialized_model) + method_params = _get_method_params(methods, key='Text_params_') -with st.spinner('Predicting class'): - predictions = predict(model=serialized_model, text_input=text_input) + model_runner = MovieReviewsModelRunner(serialized_model) -top_indices, top_labels = _get_top_indices_and_labels( - predictions=predictions[0], labels=labels) + with st.spinner('Predicting class'): + predictions = predict(model=serialized_model, text_input=text_input) + + top_indices, top_labels = _get_top_indices_and_labels( + predictions=predictions[0], labels=labels) weight = 0.85 / len(methods) column_spec = [0.15, *[weight for _ in methods]] diff --git a/dianna/dashboard/pages/Time_series.py b/dianna/dashboard/pages/Time_series.py index 133cebb4..a7d6f5f5 100644 --- a/dianna/dashboard/pages/Time_series.py +++ b/dianna/dashboard/pages/Time_series.py @@ -113,15 +113,17 @@ def preprocess(data): labels = load_labels(ts_label_file) choices = ('LIME', 'RISE') -methods = _methods_checkboxes(choices=choices, key='TS_cb_') -method_params = _get_method_params(methods, key='TS_params_') +with st.container(border=True): + methods = _methods_checkboxes(choices=choices, key='TS_cb_') -with st.spinner('Predicting class'): - predictions = predict(model=serialized_model, ts_data=ts_data_model) + method_params = _get_method_params(methods, key='TS_params_') -top_indices, top_labels = _get_top_indices_and_labels( - predictions=predictions[0], labels=labels) + with st.spinner('Predicting class'): + predictions = predict(model=serialized_model, ts_data=ts_data_model) + + top_indices, top_labels = _get_top_indices_and_labels( + predictions=predictions[0], labels=labels) weight = 0.9 / len(methods) column_spec = [0.1, *[weight for _ in methods]] From 2e444e4680ea4746ae4c1310a2b7af28109840cb Mon Sep 17 00:00:00 2001 From: Laura Ootes Date: Thu, 1 Aug 2024 14:19:06 +0200 Subject: [PATCH 02/92] fix text --- dianna/dashboard/_shared.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dianna/dashboard/_shared.py b/dianna/dashboard/_shared.py index a0099e5f..1ad29052 100644 --- a/dianna/dashboard/_shared.py +++ b/dianna/dashboard/_shared.py @@ -122,7 +122,7 @@ def _get_top_indices_and_labels(*, predictions, labels): c1, c2 = st.columns(2) with c2: - n_top = st.number_input('Number of top results to show', + n_top = st.number_input('Number of top classes to show', value=2, min_value=1, max_value=len(labels)) From 3ce30b6d620bb9184ca7ad137224f4be92f4367f Mon Sep 17 00:00:00 2001 From: Laura Ootes Date: Thu, 1 Aug 2024 14:44:24 +0200 Subject: [PATCH 03/92] set message before tickboxes --- dianna/dashboard/_shared.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/dianna/dashboard/_shared.py b/dianna/dashboard/_shared.py index 1ad29052..e46acdb3 100644 --- a/dianna/dashboard/_shared.py +++ b/dianna/dashboard/_shared.py @@ -59,13 +59,18 @@ def _methods_checkboxes(*, choices: Sequence, key): """Get methods from a horizontal row of checkboxes.""" n_choices = len(choices) methods = [] + + # Create a container for the message + message_container = st.empty() + for col, method in zip(st.columns(n_choices), choices): with col: if st.checkbox(method, key=key + method): methods.append(method) if not methods: - st.info('Select a method to continue') + # Pu the message in the container above + message_container.info('Select a method to continue') st.stop() return methods From 125d0a8fa839242193091b6e223da019f628f27d Mon Sep 17 00:00:00 2001 From: Laura Ootes Date: Thu, 1 Aug 2024 15:00:37 +0200 Subject: [PATCH 04/92] specify that the row label represents a class --- dianna/dashboard/pages/Images.py | 2 +- dianna/dashboard/pages/Text.py | 2 +- dianna/dashboard/pages/Time_series.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/dianna/dashboard/pages/Images.py b/dianna/dashboard/pages/Images.py index be6bb066..0caf06e7 100644 --- a/dianna/dashboard/pages/Images.py +++ b/dianna/dashboard/pages/Images.py @@ -112,7 +112,7 @@ for index, label in zip(top_indices, top_labels): index_col, *columns = st.columns(column_spec) - index_col.markdown(f'##### {label}') + index_col.markdown(f'##### Class: {label}') for col, method in zip(columns, methods): kwargs = method_params[method].copy() diff --git a/dianna/dashboard/pages/Text.py b/dianna/dashboard/pages/Text.py index 8511d077..1719f38f 100644 --- a/dianna/dashboard/pages/Text.py +++ b/dianna/dashboard/pages/Text.py @@ -102,7 +102,7 @@ for index, label in zip(top_indices, top_labels): index_col, *columns = st.columns(column_spec) - index_col.markdown(f'##### {label}') + index_col.markdown(f'##### Class: {label}') for col, method in zip(columns, methods): kwargs = method_params[method].copy() diff --git a/dianna/dashboard/pages/Time_series.py b/dianna/dashboard/pages/Time_series.py index a7d6f5f5..7a6db190 100644 --- a/dianna/dashboard/pages/Time_series.py +++ b/dianna/dashboard/pages/Time_series.py @@ -134,7 +134,7 @@ def preprocess(data): for index, label in zip(top_indices, top_labels): index_col, *columns = st.columns(column_spec) - index_col.markdown(f'##### {label}') + index_col.markdown(f'##### Class: {label}') for col, method in zip(columns, methods): kwargs = method_params[method].copy() From 911965e3df8b1e3e8ada273fe9db8943e5f3982c Mon Sep 17 00:00:00 2001 From: Laura Ootes Date: Thu, 1 Aug 2024 15:02:16 +0200 Subject: [PATCH 05/92] switch rise and lime for consistency with other pages --- dianna/dashboard/pages/Time_series.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dianna/dashboard/pages/Time_series.py b/dianna/dashboard/pages/Time_series.py index 7a6db190..a125b7de 100644 --- a/dianna/dashboard/pages/Time_series.py +++ b/dianna/dashboard/pages/Time_series.py @@ -112,7 +112,7 @@ def preprocess(data): labels = load_labels(ts_label_file) -choices = ('LIME', 'RISE') +choices = ('RISE', 'LIME') with st.container(border=True): methods = _methods_checkboxes(choices=choices, key='TS_cb_') From 2957be91215a2fcee4ef2676154aac2c8802cef1 Mon Sep 17 00:00:00 2001 From: Laura Ootes Date: Thu, 1 Aug 2024 16:19:49 +0200 Subject: [PATCH 06/92] move method_params into method checkboxes function --- dianna/dashboard/pages/Images.py | 5 +---- dianna/dashboard/pages/Text.py | 5 +---- dianna/dashboard/pages/Time_series.py | 5 +---- 3 files changed, 3 insertions(+), 12 deletions(-) diff --git a/dianna/dashboard/pages/Images.py b/dianna/dashboard/pages/Images.py index 0caf06e7..230b5995 100644 --- a/dianna/dashboard/pages/Images.py +++ b/dianna/dashboard/pages/Images.py @@ -4,7 +4,6 @@ from _model_utils import load_model from _models_image import explain_image_dispatcher from _models_image import predict -from _shared import _get_method_params from _shared import _get_top_indices_and_labels from _shared import _methods_checkboxes from _shared import add_sidebar_logo @@ -89,9 +88,7 @@ choices = ('RISE', 'KernelSHAP', 'LIME') with st.container(border=True): - methods = _methods_checkboxes(choices=choices, key='Image_cb_') - - method_params = _get_method_params(methods, key='Image_params_') + methods, method_params = _methods_checkboxes(choices=choices, key='Image_cb_') with st.spinner('Predicting class'): predictions = predict(model=model, image=image) diff --git a/dianna/dashboard/pages/Text.py b/dianna/dashboard/pages/Text.py index 1719f38f..4ca25df4 100644 --- a/dianna/dashboard/pages/Text.py +++ b/dianna/dashboard/pages/Text.py @@ -4,7 +4,6 @@ from _models_text import explain_text_dispatcher from _models_text import predict from _movie_model import MovieReviewsModelRunner -from _shared import _get_method_params from _shared import _get_top_indices_and_labels from _shared import _methods_checkboxes from _shared import add_sidebar_logo @@ -80,9 +79,7 @@ choices = ('RISE', 'LIME') with st.container(border=True): - methods = _methods_checkboxes(choices=choices, key='Text_cb_') - - method_params = _get_method_params(methods, key='Text_params_') + methods, method_params = _methods_checkboxes(choices=choices, key='Text_cb_') model_runner = MovieReviewsModelRunner(serialized_model) diff --git a/dianna/dashboard/pages/Time_series.py b/dianna/dashboard/pages/Time_series.py index a125b7de..c6779a1c 100644 --- a/dianna/dashboard/pages/Time_series.py +++ b/dianna/dashboard/pages/Time_series.py @@ -3,7 +3,6 @@ from _model_utils import load_model from _models_ts import explain_ts_dispatcher from _models_ts import predict -from _shared import _get_method_params from _shared import _get_top_indices_and_labels from _shared import _methods_checkboxes from _shared import add_sidebar_logo @@ -115,9 +114,7 @@ def preprocess(data): choices = ('RISE', 'LIME') with st.container(border=True): - methods = _methods_checkboxes(choices=choices, key='TS_cb_') - - method_params = _get_method_params(methods, key='TS_params_') + methods, method_params = _methods_checkboxes(choices=choices, key='TS_cb_') with st.spinner('Predicting class'): predictions = predict(model=serialized_model, ts_data=ts_data_model) From bf164bb696dd432bdb48b977dbaf51dc008cb142 Mon Sep 17 00:00:00 2001 From: Laura Ootes Date: Thu, 1 Aug 2024 16:20:31 +0200 Subject: [PATCH 07/92] move method params into method checkboxes to get the params expander per method --- dianna/dashboard/_shared.py | 26 +++++++++----------------- 1 file changed, 9 insertions(+), 17 deletions(-) diff --git a/dianna/dashboard/_shared.py b/dianna/dashboard/_shared.py index e46acdb3..a5ae0bd6 100644 --- a/dianna/dashboard/_shared.py +++ b/dianna/dashboard/_shared.py @@ -56,10 +56,12 @@ def add_sidebar_logo(): def _methods_checkboxes(*, choices: Sequence, key): - """Get methods from a horizontal row of checkboxes.""" + """Get methods from a horizontal row of checkboxes + and the corresponding parameters.""" n_choices = len(choices) methods = [] - + method_params = {} + # Create a container for the message message_container = st.empty() @@ -67,13 +69,15 @@ def _methods_checkboxes(*, choices: Sequence, key): with col: if st.checkbox(method, key=key + method): methods.append(method) + with st.expander(f'Click to modify {method} parameters'): + st.header(method) + method_params[method] = _get_params(method, key=key) if not methods: - # Pu the message in the container above + # Put the message in the container above message_container.info('Select a method to continue') st.stop() - - return methods + return methods, method_params def _get_params(method: str, key): @@ -104,18 +108,6 @@ def _get_params(method: str, key): raise ValueError(f'No such method: {method}') -def _get_method_params(methods: Sequence[str], key) -> Dict[str, Dict[str, Any]]: - method_params = {} - - with st.expander('Click to modify method parameters'): - for method, col in zip(methods, st.columns(len(methods))): - with col: - st.header(method) - method_params[method] = _get_params(method, key=key) - - return method_params - - def _get_top_indices(predictions, n_top): indices = np.array(np.argpartition(predictions, -n_top)[-n_top:]) indices = indices[np.argsort(predictions[indices])] From acc1a32917427a04d9fb90aa082d58c053ea21ed Mon Sep 17 00:00:00 2001 From: Laura Ootes Date: Thu, 1 Aug 2024 16:22:28 +0200 Subject: [PATCH 08/92] delete method header since they are per method now --- dianna/dashboard/_shared.py | 1 - 1 file changed, 1 deletion(-) diff --git a/dianna/dashboard/_shared.py b/dianna/dashboard/_shared.py index a5ae0bd6..5e6c2a14 100644 --- a/dianna/dashboard/_shared.py +++ b/dianna/dashboard/_shared.py @@ -70,7 +70,6 @@ def _methods_checkboxes(*, choices: Sequence, key): if st.checkbox(method, key=key + method): methods.append(method) with st.expander(f'Click to modify {method} parameters'): - st.header(method) method_params[method] = _get_params(method, key=key) if not methods: From ac85823aa2a9b4f6a2df43e518f3986bb1cc57cf Mon Sep 17 00:00:00 2001 From: Laura Ootes Date: Thu, 1 Aug 2024 16:42:53 +0200 Subject: [PATCH 09/92] make method axis label equal size as class axis label --- dianna/dashboard/pages/Images.py | 2 +- dianna/dashboard/pages/Text.py | 2 +- dianna/dashboard/pages/Time_series.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/dianna/dashboard/pages/Images.py b/dianna/dashboard/pages/Images.py index 230b5995..eb96e38f 100644 --- a/dianna/dashboard/pages/Images.py +++ b/dianna/dashboard/pages/Images.py @@ -105,7 +105,7 @@ _, *columns = st.columns(column_spec) for col, method in zip(columns, methods): - col.header(method) + col.markdown(f'##### {method}') for index, label in zip(top_indices, top_labels): index_col, *columns = st.columns(column_spec) diff --git a/dianna/dashboard/pages/Text.py b/dianna/dashboard/pages/Text.py index 4ca25df4..0dce1584 100644 --- a/dianna/dashboard/pages/Text.py +++ b/dianna/dashboard/pages/Text.py @@ -94,7 +94,7 @@ _, *columns = st.columns(column_spec) for col, method in zip(columns, methods): - col.header(method) + col.markdown(f'##### {method}') for index, label in zip(top_indices, top_labels): index_col, *columns = st.columns(column_spec) diff --git a/dianna/dashboard/pages/Time_series.py b/dianna/dashboard/pages/Time_series.py index c6779a1c..9fc66d8b 100644 --- a/dianna/dashboard/pages/Time_series.py +++ b/dianna/dashboard/pages/Time_series.py @@ -127,7 +127,7 @@ def preprocess(data): _, *columns = st.columns(column_spec) for col, method in zip(columns, methods): - col.header(method) + col.markdown(f'##### {method}') for index, label in zip(top_indices, top_labels): index_col, *columns = st.columns(column_spec) From 4d8ef36df0672248150d13bbb350cb8c8ce8fd41 Mon Sep 17 00:00:00 2001 From: Laura Ootes Date: Thu, 1 Aug 2024 17:10:48 +0200 Subject: [PATCH 10/92] move prediction to top ourside of container --- dianna/dashboard/pages/Images.py | 7 +++++-- dianna/dashboard/pages/Text.py | 3 +++ dianna/dashboard/pages/Time_series.py | 4 ++++ 3 files changed, 12 insertions(+), 2 deletions(-) diff --git a/dianna/dashboard/pages/Images.py b/dianna/dashboard/pages/Images.py index eb96e38f..d2047276 100644 --- a/dianna/dashboard/pages/Images.py +++ b/dianna/dashboard/pages/Images.py @@ -87,14 +87,17 @@ choices = ('RISE', 'KernelSHAP', 'LIME') +prediction_placeholder = st.empty() + with st.container(border=True): methods, method_params = _methods_checkboxes(choices=choices, key='Image_cb_') with st.spinner('Predicting class'): predictions = predict(model=model, image=image) - top_indices, top_labels = _get_top_indices_and_labels(predictions=predictions, - labels=labels) +with prediction_placeholder: + top_indices, top_labels = _get_top_indices_and_labels( + predictions=predictions,labels=labels) # check which axis is color channel original_data = image[:, :, 0] if image.shape[2] <= 3 else image[1, :, :] diff --git a/dianna/dashboard/pages/Text.py b/dianna/dashboard/pages/Text.py index 0dce1584..c1e32e45 100644 --- a/dianna/dashboard/pages/Text.py +++ b/dianna/dashboard/pages/Text.py @@ -78,6 +78,8 @@ choices = ('RISE', 'LIME') +prediction_placeholder = st.empty() + with st.container(border=True): methods, method_params = _methods_checkboxes(choices=choices, key='Text_cb_') @@ -86,6 +88,7 @@ with st.spinner('Predicting class'): predictions = predict(model=serialized_model, text_input=text_input) +with prediction_placeholder: top_indices, top_labels = _get_top_indices_and_labels( predictions=predictions[0], labels=labels) diff --git a/dianna/dashboard/pages/Time_series.py b/dianna/dashboard/pages/Time_series.py index 9fc66d8b..71cf9bb1 100644 --- a/dianna/dashboard/pages/Time_series.py +++ b/dianna/dashboard/pages/Time_series.py @@ -113,12 +113,16 @@ def preprocess(data): choices = ('RISE', 'LIME') +prediction_placeholder = st.empty() + with st.container(border=True): + methods, method_params = _methods_checkboxes(choices=choices, key='TS_cb_') with st.spinner('Predicting class'): predictions = predict(model=serialized_model, ts_data=ts_data_model) +with prediction_placeholder: top_indices, top_labels = _get_top_indices_and_labels( predictions=predictions[0], labels=labels) From 6f89ba3b1de19ddfabc8146737a7fc4e86519bbf Mon Sep 17 00:00:00 2001 From: Laura Ootes Date: Thu, 1 Aug 2024 17:24:30 +0200 Subject: [PATCH 11/92] center titles --- dianna/dashboard/pages/Images.py | 2 +- dianna/dashboard/pages/Text.py | 2 +- dianna/dashboard/pages/Time_series.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/dianna/dashboard/pages/Images.py b/dianna/dashboard/pages/Images.py index d2047276..f19e7fb5 100644 --- a/dianna/dashboard/pages/Images.py +++ b/dianna/dashboard/pages/Images.py @@ -108,7 +108,7 @@ _, *columns = st.columns(column_spec) for col, method in zip(columns, methods): - col.markdown(f'##### {method}') + col.markdown(f"

{method}

", unsafe_allow_html=True) for index, label in zip(top_indices, top_labels): index_col, *columns = st.columns(column_spec) diff --git a/dianna/dashboard/pages/Text.py b/dianna/dashboard/pages/Text.py index c1e32e45..fad3c063 100644 --- a/dianna/dashboard/pages/Text.py +++ b/dianna/dashboard/pages/Text.py @@ -97,7 +97,7 @@ _, *columns = st.columns(column_spec) for col, method in zip(columns, methods): - col.markdown(f'##### {method}') + col.markdown(f"

{method}

", unsafe_allow_html=True) for index, label in zip(top_indices, top_labels): index_col, *columns = st.columns(column_spec) diff --git a/dianna/dashboard/pages/Time_series.py b/dianna/dashboard/pages/Time_series.py index 71cf9bb1..c3bce835 100644 --- a/dianna/dashboard/pages/Time_series.py +++ b/dianna/dashboard/pages/Time_series.py @@ -131,7 +131,7 @@ def preprocess(data): _, *columns = st.columns(column_spec) for col, method in zip(columns, methods): - col.markdown(f'##### {method}') + col.markdown(f"

{method}

", unsafe_allow_html=True) for index, label in zip(top_indices, top_labels): index_col, *columns = st.columns(column_spec) From 1918b3ae8d7b0200fa4c2998361c2e0d2c11d5ac Mon Sep 17 00:00:00 2001 From: Laura Ootes Date: Thu, 1 Aug 2024 17:27:47 +0200 Subject: [PATCH 12/92] add some space --- dianna/dashboard/pages/Images.py | 3 +++ dianna/dashboard/pages/Text.py | 3 +++ dianna/dashboard/pages/Time_series.py | 3 +++ 3 files changed, 9 insertions(+) diff --git a/dianna/dashboard/pages/Images.py b/dianna/dashboard/pages/Images.py index f19e7fb5..64a6392b 100644 --- a/dianna/dashboard/pages/Images.py +++ b/dianna/dashboard/pages/Images.py @@ -99,6 +99,9 @@ top_indices, top_labels = _get_top_indices_and_labels( predictions=predictions,labels=labels) +st.text("") +st.text("") + # check which axis is color channel original_data = image[:, :, 0] if image.shape[2] <= 3 else image[1, :, :] axis_labels = {2: 'channels'} if image.shape[2] <= 3 else {0: 'channels'} diff --git a/dianna/dashboard/pages/Text.py b/dianna/dashboard/pages/Text.py index fad3c063..41c5b84d 100644 --- a/dianna/dashboard/pages/Text.py +++ b/dianna/dashboard/pages/Text.py @@ -92,6 +92,9 @@ top_indices, top_labels = _get_top_indices_and_labels( predictions=predictions[0], labels=labels) +st.text("") +st.text("") + weight = 0.85 / len(methods) column_spec = [0.15, *[weight for _ in methods]] diff --git a/dianna/dashboard/pages/Time_series.py b/dianna/dashboard/pages/Time_series.py index c3bce835..3241c3d4 100644 --- a/dianna/dashboard/pages/Time_series.py +++ b/dianna/dashboard/pages/Time_series.py @@ -126,6 +126,9 @@ def preprocess(data): top_indices, top_labels = _get_top_indices_and_labels( predictions=predictions[0], labels=labels) +st.text("") +st.text("") + weight = 0.9 / len(methods) column_spec = [0.1, *[weight for _ in methods]] From 4c62ab36781f8feb2c3f2521a53d1de5dd692e56 Mon Sep 17 00:00:00 2001 From: Laura Ootes Date: Fri, 2 Aug 2024 09:51:34 +0200 Subject: [PATCH 13/92] update text label --- dianna/dashboard/_shared.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/dianna/dashboard/_shared.py b/dianna/dashboard/_shared.py index 5e6c2a14..2025fbde 100644 --- a/dianna/dashboard/_shared.py +++ b/dianna/dashboard/_shared.py @@ -76,6 +76,7 @@ def _methods_checkboxes(*, choices: Sequence, key): # Put the message in the container above message_container.info('Select a method to continue') st.stop() + return methods, method_params @@ -127,7 +128,7 @@ def _get_top_indices_and_labels(*, predictions, labels): top_labels = [labels[i] for i in top_indices] with c1: - st.metric('Predicted class', top_labels[0]) + st.metric('Predicted class:', top_labels[0]) return top_indices, top_labels From eb9af6dc6132878bbac8a84dc536f23a7aa2dc60 Mon Sep 17 00:00:00 2001 From: Laura Ootes Date: Fri, 2 Aug 2024 12:10:22 +0200 Subject: [PATCH 14/92] make the row elements smaller --- dianna/dashboard/_shared.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/dianna/dashboard/_shared.py b/dianna/dashboard/_shared.py index 2025fbde..64473ab8 100644 --- a/dianna/dashboard/_shared.py +++ b/dianna/dashboard/_shared.py @@ -116,9 +116,9 @@ def _get_top_indices(predictions, n_top): def _get_top_indices_and_labels(*, predictions, labels): - c1, c2 = st.columns(2) + cols = st.columns(4) - with c2: + with cols[-1]: n_top = st.number_input('Number of top classes to show', value=2, min_value=1, @@ -127,7 +127,7 @@ def _get_top_indices_and_labels(*, predictions, labels): top_indices = _get_top_indices(predictions, n_top) top_labels = [labels[i] for i in top_indices] - with c1: + with cols[0]: st.metric('Predicted class:', top_labels[0]) return top_indices, top_labels From 7459ed3050ab7cdb31e88f09a89d5f1519598115 Mon Sep 17 00:00:00 2001 From: Laura Ootes Date: Fri, 2 Aug 2024 12:11:49 +0200 Subject: [PATCH 15/92] add some space --- dianna/dashboard/pages/Images.py | 2 ++ dianna/dashboard/pages/Text.py | 2 ++ dianna/dashboard/pages/Time_series.py | 2 ++ 3 files changed, 6 insertions(+) diff --git a/dianna/dashboard/pages/Images.py b/dianna/dashboard/pages/Images.py index 64a6392b..d7c65746 100644 --- a/dianna/dashboard/pages/Images.py +++ b/dianna/dashboard/pages/Images.py @@ -87,6 +87,8 @@ choices = ('RISE', 'KernelSHAP', 'LIME') +st.text("") +st.text("") prediction_placeholder = st.empty() with st.container(border=True): diff --git a/dianna/dashboard/pages/Text.py b/dianna/dashboard/pages/Text.py index 41c5b84d..f5104d20 100644 --- a/dianna/dashboard/pages/Text.py +++ b/dianna/dashboard/pages/Text.py @@ -78,6 +78,8 @@ choices = ('RISE', 'LIME') +st.text("") +st.text("") prediction_placeholder = st.empty() with st.container(border=True): diff --git a/dianna/dashboard/pages/Time_series.py b/dianna/dashboard/pages/Time_series.py index 3241c3d4..4cbc2bda 100644 --- a/dianna/dashboard/pages/Time_series.py +++ b/dianna/dashboard/pages/Time_series.py @@ -113,6 +113,8 @@ def preprocess(data): choices = ('RISE', 'LIME') +st.text("") +st.text("") prediction_placeholder = st.empty() with st.container(border=True): From 8432fe3878cba0a5389677bbfc8b40d08e0a6483 Mon Sep 17 00:00:00 2001 From: Laura Ootes Date: Fri, 2 Aug 2024 12:27:19 +0200 Subject: [PATCH 16/92] remove sidebar from home page --- dianna/dashboard/Home.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/dianna/dashboard/Home.py b/dianna/dashboard/Home.py index e15dc128..cbd41a07 100644 --- a/dianna/dashboard/Home.py +++ b/dianna/dashboard/Home.py @@ -38,8 +38,6 @@ # Display the content of the selected page if selected == "Home": - add_sidebar_logo() - st.image(str(data_directory / 'logo.png')) st.markdown(""" From 3b7b5f4d060b63adb3137d406a4de72ccf776dd2 Mon Sep 17 00:00:00 2001 From: Laura Ootes Date: Fri, 2 Aug 2024 12:28:38 +0200 Subject: [PATCH 17/92] use simple logo add --- dianna/dashboard/_shared.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/dianna/dashboard/_shared.py b/dianna/dashboard/_shared.py index 64473ab8..aef4a19b 100644 --- a/dianna/dashboard/_shared.py +++ b/dianna/dashboard/_shared.py @@ -46,13 +46,8 @@ def build_markup_for_logo( def add_sidebar_logo(): - """Based on: https://stackoverflow.com/a/73278825.""" - png_file = data_directory / 'logo.png' - logo_markup = build_markup_for_logo(png_file) - st.markdown( - logo_markup, - unsafe_allow_html=True, - ) + "Upload DIANNA logo to sidebar element" + st.sidebar.image(str(data_directory / 'logo.png')) def _methods_checkboxes(*, choices: Sequence, key): From 310da69ee7aff0ae94a4d74cf3301618a55c9956 Mon Sep 17 00:00:00 2001 From: Laura Ootes Date: Fri, 2 Aug 2024 13:49:00 +0200 Subject: [PATCH 18/92] move predicted class inside container --- dianna/dashboard/pages/Images.py | 8 ++++---- dianna/dashboard/pages/Text.py | 8 ++++---- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/dianna/dashboard/pages/Images.py b/dianna/dashboard/pages/Images.py index d7c65746..23503cba 100644 --- a/dianna/dashboard/pages/Images.py +++ b/dianna/dashboard/pages/Images.py @@ -89,17 +89,17 @@ st.text("") st.text("") -prediction_placeholder = st.empty() with st.container(border=True): + prediction_placeholder = st.empty() methods, method_params = _methods_checkboxes(choices=choices, key='Image_cb_') with st.spinner('Predicting class'): predictions = predict(model=model, image=image) -with prediction_placeholder: - top_indices, top_labels = _get_top_indices_and_labels( - predictions=predictions,labels=labels) + with prediction_placeholder: + top_indices, top_labels = _get_top_indices_and_labels( + predictions=predictions,labels=labels) st.text("") st.text("") diff --git a/dianna/dashboard/pages/Text.py b/dianna/dashboard/pages/Text.py index f5104d20..5d295077 100644 --- a/dianna/dashboard/pages/Text.py +++ b/dianna/dashboard/pages/Text.py @@ -80,9 +80,9 @@ st.text("") st.text("") -prediction_placeholder = st.empty() with st.container(border=True): + prediction_placeholder = st.empty() methods, method_params = _methods_checkboxes(choices=choices, key='Text_cb_') model_runner = MovieReviewsModelRunner(serialized_model) @@ -90,9 +90,9 @@ with st.spinner('Predicting class'): predictions = predict(model=serialized_model, text_input=text_input) -with prediction_placeholder: - top_indices, top_labels = _get_top_indices_and_labels( - predictions=predictions[0], labels=labels) + with prediction_placeholder: + top_indices, top_labels = _get_top_indices_and_labels( + predictions=predictions[0], labels=labels) st.text("") st.text("") From ed61dbbcf722b11e51a984fda8c39697545ddb07 Mon Sep 17 00:00:00 2001 From: Laura Ootes Date: Fri, 2 Aug 2024 13:49:34 +0200 Subject: [PATCH 19/92] move predicted class inside container and all sidebar logo --- dianna/dashboard/pages/Time_series.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/dianna/dashboard/pages/Time_series.py b/dianna/dashboard/pages/Time_series.py index 4cbc2bda..51f85547 100644 --- a/dianna/dashboard/pages/Time_series.py +++ b/dianna/dashboard/pages/Time_series.py @@ -6,6 +6,7 @@ from _shared import _get_top_indices_and_labels from _shared import _methods_checkboxes from _shared import add_sidebar_logo +from _shared import build_markup_for_logo from _shared import data_directory from _shared import label_directory from _shared import model_directory @@ -16,10 +17,12 @@ from dianna.visualization import plot_timeseries, plot_image import numpy as np -add_sidebar_logo() + + st.title('Time series explanation') +add_sidebar_logo() st.sidebar.header('Input data') input_type = st.sidebar.radio( @@ -115,18 +118,18 @@ def preprocess(data): st.text("") st.text("") -prediction_placeholder = st.empty() -with st.container(border=True): +with st.container(border=True): + prediction_placeholder = st.empty() methods, method_params = _methods_checkboxes(choices=choices, key='TS_cb_') with st.spinner('Predicting class'): predictions = predict(model=serialized_model, ts_data=ts_data_model) -with prediction_placeholder: - top_indices, top_labels = _get_top_indices_and_labels( - predictions=predictions[0], labels=labels) + with prediction_placeholder: + top_indices, top_labels = _get_top_indices_and_labels( + predictions=predictions[0], labels=labels) st.text("") st.text("") From afc30923570762af9436230f143e19611dd12941 Mon Sep 17 00:00:00 2001 From: Laura Ootes Date: Mon, 5 Aug 2024 08:57:58 +0200 Subject: [PATCH 20/92] use page wide layout instead of small centered --- dianna/dashboard/Home.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dianna/dashboard/Home.py b/dianna/dashboard/Home.py index 91a86707..3041e1cc 100644 --- a/dianna/dashboard/Home.py +++ b/dianna/dashboard/Home.py @@ -6,7 +6,7 @@ st.set_page_config(page_title="Dianna's dashboard", page_icon='📊', - layout='centered', + layout='wide', initial_sidebar_state='auto', menu_items={ 'Get help': From 4a3c451c797f876c5cb5c0643873353c81dbc489 Mon Sep 17 00:00:00 2001 From: Laura Ootes Date: Wed, 7 Aug 2024 16:15:36 +0200 Subject: [PATCH 21/92] fix typo --- dianna/dashboard/pages/Images.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/dianna/dashboard/pages/Images.py b/dianna/dashboard/pages/Images.py index ff2b3171..87dd8dc6 100644 --- a/dianna/dashboard/pages/Images.py +++ b/dianna/dashboard/pages/Images.py @@ -35,7 +35,7 @@ options=('Hand-written digit recognition',), index = None, on_change = reset_method, - key='Image_load example' + key='Image_load_example' ) if load_example == 'Hand-written digit recognition': @@ -131,6 +131,7 @@ with col: with st.spinner(f'Running {method}'): + print('index', index) heatmap = func(serialized_model, image, index, **kwargs) fig, _ = plot_image(heatmap, From 80603d4d03fd022686d4dc734e8fd7f30cb8e3ef Mon Sep 17 00:00:00 2001 From: Laura Ootes Date: Wed, 7 Aug 2024 16:32:37 +0200 Subject: [PATCH 22/92] fix double checkbox --- dianna/dashboard/pages/Time_series.py | 1 - 1 file changed, 1 deletion(-) diff --git a/dianna/dashboard/pages/Time_series.py b/dianna/dashboard/pages/Time_series.py index 830d6136..6c5fb98e 100644 --- a/dianna/dashboard/pages/Time_series.py +++ b/dianna/dashboard/pages/Time_series.py @@ -123,7 +123,6 @@ def preprocess(data): choices = ('RISE',) else: choices = ('RISE', 'LIME') -methods = _methods_checkboxes(choices=choices, key='TS_cb_') st.text("") st.text("") From dbdbc67b8249aec3ca1a327fcd69b99a31795d3b Mon Sep 17 00:00:00 2001 From: Laura Ootes Date: Wed, 7 Aug 2024 16:32:52 +0200 Subject: [PATCH 23/92] fix key label --- dianna/dashboard/pages/Text.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dianna/dashboard/pages/Text.py b/dianna/dashboard/pages/Text.py index cba16449..03456978 100644 --- a/dianna/dashboard/pages/Text.py +++ b/dianna/dashboard/pages/Text.py @@ -34,7 +34,7 @@ options=('Movie sentiment',), index = None, on_change = reset_method, - key='Text_example_check_moviesentiment') + key='Text_load_example') if load_example == 'Movie sentiment': text_input = 'The movie started out great but the ending was dissappointing' From 906b6e915416f0d535707c01f9e65c037da23001 Mon Sep 17 00:00:00 2001 From: Laura Ootes Date: Thu, 8 Aug 2024 13:57:25 +0200 Subject: [PATCH 24/92] fix popping --- dianna/dashboard/_shared.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/dianna/dashboard/_shared.py b/dianna/dashboard/_shared.py index f0c8b19b..860da3c5 100644 --- a/dianna/dashboard/_shared.py +++ b/dianna/dashboard/_shared.py @@ -65,7 +65,7 @@ def _methods_checkboxes(*, choices: Sequence, key): if st.checkbox(method, key=key + method): methods.append(method) with st.expander(f'Click to modify {method} parameters'): - method_params[method] = _get_params(method, key=key) + method_params[method] = _get_params(method, key=f'{key}param_') if not methods: # Put the message in the container above @@ -130,10 +130,10 @@ def _get_top_indices_and_labels(*, predictions, labels): def reset_method(): # Clear selection for k in st.session_state.keys(): - if '_cb_' in k: - st.session_state[k] = False - if 'params' in k: + if '_param' in k: st.session_state.pop(k) + elif '_cb' in k: + st.session_state[k] = False def reset_example(): # Clear selection From eb5d3f46a26144155efb5748bb94790e7548c1b9 Mon Sep 17 00:00:00 2001 From: Laura Ootes Date: Thu, 8 Aug 2024 14:05:21 +0200 Subject: [PATCH 25/92] use f strings --- dianna/dashboard/_shared.py | 20 ++++++++++---------- dianna/dashboard/pages/Images.py | 2 +- dianna/dashboard/pages/Text.py | 2 +- dianna/dashboard/pages/Time_series.py | 2 +- 4 files changed, 13 insertions(+), 13 deletions(-) diff --git a/dianna/dashboard/_shared.py b/dianna/dashboard/_shared.py index 860da3c5..e807cc20 100644 --- a/dianna/dashboard/_shared.py +++ b/dianna/dashboard/_shared.py @@ -62,10 +62,10 @@ def _methods_checkboxes(*, choices: Sequence, key): for col, method in zip(st.columns(n_choices), choices): with col: - if st.checkbox(method, key=key + method): + if st.checkbox(method, key=f'{key}_{method}'): methods.append(method) with st.expander(f'Click to modify {method} parameters'): - method_params[method] = _get_params(method, key=f'{key}param_') + method_params[method] = _get_params(method, key=f'{key}_param') if not methods: # Put the message in the container above @@ -79,24 +79,24 @@ def _get_params(method: str, key): if method == 'RISE': return { 'n_masks': - st.number_input('Number of masks', value=1000, key=key + method + 'nmasks'), + st.number_input('Number of masks', value=1000, key=f'{key}_{method}_nmasks'), 'feature_res': - st.number_input('Feature resolution', value=6, key=key + method + 'fr'), + st.number_input('Feature resolution', value=6, key=f'{key}_{method}_fr'), 'p_keep': - st.number_input('Probability to be kept unmasked', value=0.1, key=key + method + 'pkeep'), + st.number_input('Probability to be kept unmasked', value=0.1, key=f'{key}_{method}_pkeep'), } elif method == 'KernelSHAP': return { - 'nsamples': st.number_input('Number of samples', value=1000, key=key + method + 'nsamp'), - 'background': st.number_input('Background', value=0, key=key + method + 'background'), - 'n_segments': st.number_input('Number of segments', value=200, key=key + method + 'nseg'), - 'sigma': st.number_input('σ', value=0, key=key + method + 'sigma'), + 'nsamples': st.number_input('Number of samples', value=1000, key=f'{key}_{method}_nsamp'), + 'background': st.number_input('Background', value=0, key=f'{key}_{method}_background'), + 'n_segments': st.number_input('Number of segments', value=200, key=f'{key}_{method}_nseg'), + 'sigma': st.number_input('σ', value=0, key=f'{key}_{method}_sigma'), } elif method == 'LIME': return { - 'random_state': st.number_input('Random state', value=2, key=key + method + 'rs'), + 'random_state': st.number_input('Random state', value=2, key=f'{key}_{method}_rs'), } else: diff --git a/dianna/dashboard/pages/Images.py b/dianna/dashboard/pages/Images.py index 87dd8dc6..5b8691e4 100644 --- a/dianna/dashboard/pages/Images.py +++ b/dianna/dashboard/pages/Images.py @@ -95,7 +95,7 @@ with st.container(border=True): prediction_placeholder = st.empty() - methods, method_params = _methods_checkboxes(choices=choices, key='Image_cb_') + methods, method_params = _methods_checkboxes(choices=choices, key='Image_cb') with st.spinner('Predicting class'): predictions = predict(model=model, image=image) diff --git a/dianna/dashboard/pages/Text.py b/dianna/dashboard/pages/Text.py index 03456978..120f8406 100644 --- a/dianna/dashboard/pages/Text.py +++ b/dianna/dashboard/pages/Text.py @@ -86,7 +86,7 @@ with st.container(border=True): prediction_placeholder = st.empty() - methods, method_params = _methods_checkboxes(choices=choices, key='Text_cb_') + methods, method_params = _methods_checkboxes(choices=choices, key='Text_cb') model_runner = MovieReviewsModelRunner(serialized_model) diff --git a/dianna/dashboard/pages/Time_series.py b/dianna/dashboard/pages/Time_series.py index 6c5fb98e..62f4503c 100644 --- a/dianna/dashboard/pages/Time_series.py +++ b/dianna/dashboard/pages/Time_series.py @@ -130,7 +130,7 @@ def preprocess(data): with st.container(border=True): prediction_placeholder = st.empty() - methods, method_params = _methods_checkboxes(choices=choices, key='TS_cb_') + methods, method_params = _methods_checkboxes(choices=choices, key='TS_cb') with st.spinner('Predicting class'): predictions = predict(model=serialized_model, ts_data=ts_data_model) From 637dae58a77d3f446dbc8e6e0ce82eb1d0c64766 Mon Sep 17 00:00:00 2001 From: Laura Ootes Date: Thu, 8 Aug 2024 14:08:23 +0200 Subject: [PATCH 26/92] remove print statement --- dianna/dashboard/pages/Images.py | 1 - 1 file changed, 1 deletion(-) diff --git a/dianna/dashboard/pages/Images.py b/dianna/dashboard/pages/Images.py index 5b8691e4..2ab97f9d 100644 --- a/dianna/dashboard/pages/Images.py +++ b/dianna/dashboard/pages/Images.py @@ -131,7 +131,6 @@ with col: with st.spinner(f'Running {method}'): - print('index', index) heatmap = func(serialized_model, image, index, **kwargs) fig, _ = plot_image(heatmap, From 4d5b04adf819702dcebbd6f91eff499fdb3f7d7d Mon Sep 17 00:00:00 2001 From: Laura Ootes Date: Thu, 8 Aug 2024 17:41:01 +0200 Subject: [PATCH 27/92] initialize tab for tabular --- dianna/dashboard/Home.py | 3 ++- dianna/dashboard/pages/Tabular.py | 0 2 files changed, 2 insertions(+), 1 deletion(-) create mode 100644 dianna/dashboard/pages/Tabular.py diff --git a/dianna/dashboard/Home.py b/dianna/dashboard/Home.py index c86fd264..b57b759b 100644 --- a/dianna/dashboard/Home.py +++ b/dianna/dashboard/Home.py @@ -22,6 +22,7 @@ pages = { "Home": "home", "Images": "pages.Images", + "Tabular": "pages.Tabular", "Text": "pages.Text", "Time series": "pages.Time_series" } @@ -30,7 +31,7 @@ selected = option_menu( menu_title=None, options=list(pages.keys()), - icons=["house", "camera", "alphabet", "clock"], + icons=["house", "camera", "table", "alphabet", "clock"], menu_icon="cast", default_index=0, orientation="horizontal" diff --git a/dianna/dashboard/pages/Tabular.py b/dianna/dashboard/pages/Tabular.py new file mode 100644 index 00000000..e69de29b From 083ebc8bdf78733b8c479fcdeb01639e34796eee Mon Sep 17 00:00:00 2001 From: Laura Ootes Date: Thu, 8 Aug 2024 17:41:40 +0200 Subject: [PATCH 28/92] add default exception for tabular --- dianna/dashboard/Home.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/dianna/dashboard/Home.py b/dianna/dashboard/Home.py index b57b759b..ffbfdbb6 100644 --- a/dianna/dashboard/Home.py +++ b/dianna/dashboard/Home.py @@ -71,6 +71,10 @@ for k in st.session_state.keys(): if 'Image' in k: st.session_state.pop(k, None) + if selected != 'Tabular': + for k in st.session_state.keys(): + if 'Tabular' in k: + st.session_state.pop(k, None) if selected != 'Text': for k in st.session_state.keys(): if 'Text' in k: From a60f3e83b6dd1e6b28a2a309cb2d21dbbfa245d8 Mon Sep 17 00:00:00 2001 From: Laura Ootes Date: Thu, 8 Aug 2024 17:45:10 +0200 Subject: [PATCH 29/92] Set default for tabular page --- dianna/dashboard/pages/Tabular.py | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/dianna/dashboard/pages/Tabular.py b/dianna/dashboard/pages/Tabular.py index e69de29b..56cc29dc 100644 --- a/dianna/dashboard/pages/Tabular.py +++ b/dianna/dashboard/pages/Tabular.py @@ -0,0 +1,27 @@ +import streamlit as st +from _image_utils import open_image +from _model_utils import load_labels +from _model_utils import load_model +from _models_image import explain_image_dispatcher +from _models_image import predict +from _shared import _get_top_indices_and_labels +from _shared import _methods_checkboxes +from _shared import add_sidebar_logo +from _shared import reset_example +from _shared import reset_method +from dianna.utils.downloader import download +from dianna.visualization import plot_image + +add_sidebar_logo() + +st.title('Tabular data explanation') + +st.sidebar.header('Input data') + +input_type = st.sidebar.radio( + label='Select which input to use', + options = ('Use an example', 'Use your own data'), + index = None, + on_change = reset_example, + key = 'Tabular_input_type' + ) \ No newline at end of file From 8ac3e51a795be37a642e46679e39baa5f1c0c697 Mon Sep 17 00:00:00 2001 From: Laura Ootes Date: Fri, 9 Aug 2024 17:02:15 +0200 Subject: [PATCH 30/92] add tabular page reference --- dianna/dashboard/Home.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/dianna/dashboard/Home.py b/dianna/dashboard/Home.py index 22100f62..dc502278 100644 --- a/dianna/dashboard/Home.py +++ b/dianna/dashboard/Home.py @@ -49,9 +49,10 @@ ### Pages - - Images - - Text - - Time series + - Image data + - Tabular data + - Text data + - Time series data ### More information From 70fb93de37cb94a9b163b0e9c20380aa43d0872e Mon Sep 17 00:00:00 2001 From: Laura Ootes Date: Fri, 9 Aug 2024 17:09:01 +0200 Subject: [PATCH 31/92] add tabular example setup --- dianna/dashboard/pages/Tabular.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/dianna/dashboard/pages/Tabular.py b/dianna/dashboard/pages/Tabular.py index 56cc29dc..698082bb 100644 --- a/dianna/dashboard/pages/Tabular.py +++ b/dianna/dashboard/pages/Tabular.py @@ -24,4 +24,17 @@ index = None, on_change = reset_example, key = 'Tabular_input_type' - ) \ No newline at end of file + ) + +# Use the examples +if input_type == 'Use an example': + """load_example = st.sidebar.radio( + label='Use example', + options=(''), + index = None, + on_change = reset_method, + key='Tabular_load_example')""" + st.info("No examples availble yet") + st.stop() + +st.stop() \ No newline at end of file From f3629bb291c687293af52b43e1130e5184b5b591 Mon Sep 17 00:00:00 2001 From: Laura Ootes Date: Fri, 9 Aug 2024 17:16:32 +0200 Subject: [PATCH 32/92] add upload menu --- dianna/dashboard/pages/Tabular.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/dianna/dashboard/pages/Tabular.py b/dianna/dashboard/pages/Tabular.py index 698082bb..42b639c3 100644 --- a/dianna/dashboard/pages/Tabular.py +++ b/dianna/dashboard/pages/Tabular.py @@ -37,4 +37,20 @@ st.info("No examples availble yet") st.stop() +# Option to upload your own data +if input_type == 'Use your own data': + tabular_data_file = st.sidebar.file_uploader('Select tabular data', type='csv') + tabular_model_file = st.sidebar.file_uploader('Select model', + type='onnx') + tabular_label_file = st.sidebar.file_uploader('Select labels', + type='txt') + +if input_type is None: + st.info('Select which input type to use in the left panel to continue') + st.stop() + +if not (tabular_data_file and tabular_model_file and tabular_label_file): + st.info('Add your input data in the left panel to continue') + st.stop() + st.stop() \ No newline at end of file From ecb61618a97fa7b834c190d13ceedbf0f3dc1396 Mon Sep 17 00:00:00 2001 From: Laura Ootes Date: Fri, 9 Aug 2024 18:04:00 +0200 Subject: [PATCH 33/92] fix imports and draft prediction --- dianna/dashboard/pages/Tabular.py | 32 +++++++++++++++++++++++++++---- 1 file changed, 28 insertions(+), 4 deletions(-) diff --git a/dianna/dashboard/pages/Tabular.py b/dianna/dashboard/pages/Tabular.py index 42b639c3..39314583 100644 --- a/dianna/dashboard/pages/Tabular.py +++ b/dianna/dashboard/pages/Tabular.py @@ -1,16 +1,14 @@ import streamlit as st -from _image_utils import open_image from _model_utils import load_labels from _model_utils import load_model -from _models_image import explain_image_dispatcher -from _models_image import predict +from _models_tabular import predict from _shared import _get_top_indices_and_labels from _shared import _methods_checkboxes from _shared import add_sidebar_logo from _shared import reset_example from _shared import reset_method from dianna.utils.downloader import download -from dianna.visualization import plot_image +from dianna.visualization import plot_tabular add_sidebar_logo() @@ -53,4 +51,30 @@ st.info('Add your input data in the left panel to continue') st.stop() +model = load_model(tabular_model_file) +serialized_model = model.SerializeToString() + +labels = load_labels(tabular_label_file) + +choices = ('RISE', 'LIME', 'KernelSHAP') + +st.text("") +st.text("") + +# Get predictions and create parameter box +with st.container(border=True): + prediction_placeholder = st.empty() + methods, method_params = _methods_checkboxes(choices=choices, key='Tabular_cb') + + with st.spinner('Predicting class'): + predictions = predict(model=serialized_model, tabular_input=tabular_data_file) + + with prediction_placeholder: + top_indices, top_labels = _get_top_indices_and_labels( + predictions=predictions, labels=labels) + +st.text("") +st.text("") + + st.stop() \ No newline at end of file From e494c62c7b8b29a21d60f200bc54b74cd251c265 Mon Sep 17 00:00:00 2001 From: Laura Ootes Date: Fri, 9 Aug 2024 18:04:14 +0200 Subject: [PATCH 34/92] draft model --- dianna/dashboard/_models_tabular.py | 54 +++++++++++++++++++++++++++++ 1 file changed, 54 insertions(+) create mode 100644 dianna/dashboard/_models_tabular.py diff --git a/dianna/dashboard/_models_tabular.py b/dianna/dashboard/_models_tabular.py new file mode 100644 index 00000000..93286489 --- /dev/null +++ b/dianna/dashboard/_models_tabular.py @@ -0,0 +1,54 @@ +import numpy as np +import streamlit as st +from dianna import explain_tabular +from dianna.utils.onnx_runner import SimpleModelRunner + + +@st.cache_data +def predict(*, model, tabular_input): + model_runner = SimpleModelRunner(model) + predictions = model_runner(tabular_input.reshape(1,-1).astype(np.float32)) + return predictions + + +@st.cache_data +def _run_rise_tabular(_model, table, **kwargs): + relevances = explain_tabular( + _model, + table, + method='RISE', + **kwargs, + ) + return relevances + + +@st.cache_data +def _run_lime_tabular(_model, table, **kwargs): + relevances = explain_tabular( + _model, + table, + method='LIME', + **kwargs, + ) + return relevances + +@st.cache_data +def _run_kernelshap_tabular(model, image, i, **kwargs): + # Kernelshap interface is different. Write model to temporary file. + with tempfile.NamedTemporaryFile() as f: + f.write(model) + f.flush() + relevances = explain_tabular( + _model, + table, + method='KernelSHAP', + **kwargs, + ) + return relevances + + +explain_text_dispatcher = { + 'RISE': _run_rise_tabular, + 'LIME': _run_lime_tabular, + 'KernelSHAP': _run_kernelshap_tabular +} From 80c284f26c45b4dbb805849bd8997c23df0201b1 Mon Sep 17 00:00:00 2001 From: Laura Ootes Date: Fri, 9 Aug 2024 22:45:32 +0200 Subject: [PATCH 35/92] add table visualisation with row selection option --- dianna/dashboard/pages/Tabular.py | 37 ++++++++++++++++++++++++------- 1 file changed, 29 insertions(+), 8 deletions(-) diff --git a/dianna/dashboard/pages/Tabular.py b/dianna/dashboard/pages/Tabular.py index 39314583..ab1951f4 100644 --- a/dianna/dashboard/pages/Tabular.py +++ b/dianna/dashboard/pages/Tabular.py @@ -1,4 +1,5 @@ import streamlit as st +from _model_utils import load_data from _model_utils import load_labels from _model_utils import load_model from _models_tabular import predict @@ -8,7 +9,10 @@ from _shared import reset_example from _shared import reset_method from dianna.utils.downloader import download +from dianna.utils.onnx_runner import SimpleModelRunner from dianna.visualization import plot_tabular +import pandas as pd +from st_aggrid import AgGrid, GridOptionsBuilder, GridUpdateMode add_sidebar_logo() @@ -51,8 +55,12 @@ st.info('Add your input data in the left panel to continue') st.stop() -model = load_model(tabular_model_file) -serialized_model = model.SerializeToString() +data = load_data(tabular_data_file) + +#model = load_model(tabular_model_file) +#serialized_model = model.SerializeToString() + +model = SimpleModelRunner(tabular_model_file) labels = load_labels(tabular_label_file) @@ -63,15 +71,28 @@ # Get predictions and create parameter box with st.container(border=True): - prediction_placeholder = st.empty() + #prediction_placeholder = st.empty() methods, method_params = _methods_checkboxes(choices=choices, key='Tabular_cb') + #prediction_placeholder = 'hello' + +st.info("Select the input data either by clicking the corresponding row in the table or input the row index above to continue.") + +# Configure Ag-Grid options +gb = GridOptionsBuilder.from_dataframe(data) +gb.configure_selection('single') +grid_options = gb.build() + +# Display the grid with the DataFrame +grid_response = AgGrid( + data, + gridOptions=grid_options, + update_mode=GridUpdateMode.SELECTION_CHANGED, + theme='streamlit' +) + +selected_row = grid_response['selected_rows'] - with st.spinner('Predicting class'): - predictions = predict(model=serialized_model, tabular_input=tabular_data_file) - with prediction_placeholder: - top_indices, top_labels = _get_top_indices_and_labels( - predictions=predictions, labels=labels) st.text("") st.text("") From 00d938606d81a50290c8b294bd4382ed1abd7877 Mon Sep 17 00:00:00 2001 From: Laura Ootes Date: Fri, 9 Aug 2024 22:46:20 +0200 Subject: [PATCH 36/92] add load data for csv to pandas --- dianna/dashboard/_model_utils.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/dianna/dashboard/_model_utils.py b/dianna/dashboard/_model_utils.py index 272e4a40..7480b42f 100644 --- a/dianna/dashboard/_model_utils.py +++ b/dianna/dashboard/_model_utils.py @@ -1,8 +1,18 @@ from pathlib import Path import numpy as np +import pandas as pd import onnx +def load_data(file): + """Open data from a file and returns it as pandas DataFrame""" + df = pd.read_csv(file, parse_dates=True) + + # Add index column + df.insert(0, 'Index', df.index) + return df + + def preprocess_function(image): """For LIME: we divided the input data by 256 for the model (binary mnist) and LIME needs RGB values.""" return (image / 256).astype(np.float32) From 04574c48327e642f1b62c797cd29921969398451 Mon Sep 17 00:00:00 2001 From: Laura Ootes Date: Fri, 9 Aug 2024 23:46:46 +0200 Subject: [PATCH 37/92] update for non-classifier model --- dianna/dashboard/_shared.py | 28 +++++++++++++++++----------- 1 file changed, 17 insertions(+), 11 deletions(-) diff --git a/dianna/dashboard/_shared.py b/dianna/dashboard/_shared.py index e807cc20..0dc2ead5 100644 --- a/dianna/dashboard/_shared.py +++ b/dianna/dashboard/_shared.py @@ -113,17 +113,23 @@ def _get_top_indices(predictions, n_top): def _get_top_indices_and_labels(*, predictions, labels): cols = st.columns(4) - with cols[-1]: - n_top = st.number_input('Number of top classes to show', - value=2, - min_value=1, - max_value=len(labels)) - - top_indices = _get_top_indices(predictions, n_top) - top_labels = [labels[i] for i in top_indices] - - with cols[0]: - st.metric('Predicted class:', top_labels[0]) + if labels is not None: + with cols[-1]: + n_top = st.number_input('Number of top classes to show', + value=1, + min_value=1, + max_value=len(labels)) + + top_indices = _get_top_indices(predictions, n_top) + top_labels = [labels[i] for i in top_indices] + + with cols[0]: + st.metric('Predicted class:', top_labels[0]) + else: + # If not a classifier, only return the predicted value + top_indices = top_labels = 0 + with cols[0]: + st.metric('Predicted value:', f"{predictions[0]:.2f}") return top_indices, top_labels From 852570e23c212a89f4302c1cb3665d41fe8c635d Mon Sep 17 00:00:00 2001 From: Laura Ootes Date: Fri, 9 Aug 2024 23:48:48 +0200 Subject: [PATCH 38/92] allow for uploads without labels if not classifier --- dianna/dashboard/pages/Tabular.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dianna/dashboard/pages/Tabular.py b/dianna/dashboard/pages/Tabular.py index ab1951f4..93c28338 100644 --- a/dianna/dashboard/pages/Tabular.py +++ b/dianna/dashboard/pages/Tabular.py @@ -51,7 +51,7 @@ st.info('Select which input type to use in the left panel to continue') st.stop() -if not (tabular_data_file and tabular_model_file and tabular_label_file): +if not (tabular_data_file and tabular_model_file): st.info('Add your input data in the left panel to continue') st.stop() From bd200457d62f89bbb2147b90f42ca6c4d1cbd17c Mon Sep 17 00:00:00 2001 From: Laura Ootes Date: Fri, 9 Aug 2024 23:49:46 +0200 Subject: [PATCH 39/92] allow for labels to be empty --- dianna/dashboard/pages/Tabular.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/dianna/dashboard/pages/Tabular.py b/dianna/dashboard/pages/Tabular.py index 93c28338..2e6931e2 100644 --- a/dianna/dashboard/pages/Tabular.py +++ b/dianna/dashboard/pages/Tabular.py @@ -62,7 +62,10 @@ model = SimpleModelRunner(tabular_model_file) -labels = load_labels(tabular_label_file) +if tabular_label_file: + labels = load_labels(tabular_label_file) +else: + labels = None choices = ('RISE', 'LIME', 'KernelSHAP') From 6714801087c071b8ef730aee194c604bd5d0b71e Mon Sep 17 00:00:00 2001 From: Laura Ootes Date: Fri, 9 Aug 2024 23:50:55 +0200 Subject: [PATCH 40/92] reset --- dianna/dashboard/pages/Tabular.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/dianna/dashboard/pages/Tabular.py b/dianna/dashboard/pages/Tabular.py index 2e6931e2..2ced2a11 100644 --- a/dianna/dashboard/pages/Tabular.py +++ b/dianna/dashboard/pages/Tabular.py @@ -74,9 +74,8 @@ # Get predictions and create parameter box with st.container(border=True): - #prediction_placeholder = st.empty() + prediction_placeholder = st.empty() methods, method_params = _methods_checkboxes(choices=choices, key='Tabular_cb') - #prediction_placeholder = 'hello' st.info("Select the input data either by clicking the corresponding row in the table or input the row index above to continue.") From 8886611a1ae79953aa369f7d348cea67ba9cfabd Mon Sep 17 00:00:00 2001 From: Laura Ootes Date: Fri, 9 Aug 2024 23:51:34 +0200 Subject: [PATCH 41/92] show prediction --- dianna/dashboard/pages/Tabular.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/dianna/dashboard/pages/Tabular.py b/dianna/dashboard/pages/Tabular.py index 2ced2a11..774ac488 100644 --- a/dianna/dashboard/pages/Tabular.py +++ b/dianna/dashboard/pages/Tabular.py @@ -92,9 +92,20 @@ theme='streamlit' ) -selected_row = grid_response['selected_rows'] +selected_row = grid_response['selected_rows']['Index'][0] +selected_data = data.iloc[selected_row, 1:].to_numpy() +if selected_row is not None: + with st.spinner('Predicting class'): + predictions = predict(modelfile=tabular_model_file, tabular_input=selected_data) + + with prediction_placeholder: + top_indices, top_labels = _get_top_indices_and_labels( + predictions=predictions[0], labels=labels) + +else: + st.write('No row selected') st.text("") st.text("") From 5e0727dddaf3332d3fdcb9a6e4ed5951aa4ad204 Mon Sep 17 00:00:00 2001 From: Laura Ootes Date: Fri, 9 Aug 2024 23:53:32 +0200 Subject: [PATCH 42/92] temp for weather data --- dianna/dashboard/_model_utils.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/dianna/dashboard/_model_utils.py b/dianna/dashboard/_model_utils.py index 7480b42f..a465c131 100644 --- a/dianna/dashboard/_model_utils.py +++ b/dianna/dashboard/_model_utils.py @@ -7,10 +7,11 @@ def load_data(file): """Open data from a file and returns it as pandas DataFrame""" df = pd.read_csv(file, parse_dates=True) - + # FIXME: only for weather example + data = df.drop(columns=['DATE', 'MONTH'])[:-1] # Add index column - df.insert(0, 'Index', df.index) - return df + data.insert(0, 'Index', data.index) + return data def preprocess_function(image): From 78d8826d72f1bf0e499bc7e575d51e23af72c27f Mon Sep 17 00:00:00 2001 From: Laura Ootes Date: Mon, 12 Aug 2024 14:51:02 +0200 Subject: [PATCH 43/92] only if input is given --- dianna/dashboard/pages/Tabular.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/dianna/dashboard/pages/Tabular.py b/dianna/dashboard/pages/Tabular.py index 774ac488..38d027b2 100644 --- a/dianna/dashboard/pages/Tabular.py +++ b/dianna/dashboard/pages/Tabular.py @@ -92,11 +92,9 @@ theme='streamlit' ) -selected_row = grid_response['selected_rows']['Index'][0] -selected_data = data.iloc[selected_row, 1:].to_numpy() - - -if selected_row is not None: +if grid_response['selected_rows'] is not None: + selected_row = grid_response['selected_rows']['Index'][0] + selected_data = data.iloc[selected_row, 1:].to_numpy() with st.spinner('Predicting class'): predictions = predict(modelfile=tabular_model_file, tabular_input=selected_data) From 891ef84c54da39ad4432176b0d65564f7885f6e2 Mon Sep 17 00:00:00 2001 From: Laura Ootes Date: Mon, 12 Aug 2024 14:51:46 +0200 Subject: [PATCH 44/92] move info message --- dianna/dashboard/pages/Tabular.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/dianna/dashboard/pages/Tabular.py b/dianna/dashboard/pages/Tabular.py index 38d027b2..ef67309f 100644 --- a/dianna/dashboard/pages/Tabular.py +++ b/dianna/dashboard/pages/Tabular.py @@ -77,7 +77,6 @@ prediction_placeholder = st.empty() methods, method_params = _methods_checkboxes(choices=choices, key='Tabular_cb') -st.info("Select the input data either by clicking the corresponding row in the table or input the row index above to continue.") # Configure Ag-Grid options gb = GridOptionsBuilder.from_dataframe(data) @@ -103,7 +102,7 @@ predictions=predictions[0], labels=labels) else: - st.write('No row selected') + st.info("Select the input data either by clicking the corresponding row in the table or input the row index above to continue.") st.text("") st.text("") From a8bcbb22be8041986da0e55d90ebbd4525d9feba Mon Sep 17 00:00:00 2001 From: Laura Ootes Date: Mon, 12 Aug 2024 14:52:12 +0200 Subject: [PATCH 45/92] change variable names --- dianna/dashboard/pages/Tabular.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dianna/dashboard/pages/Tabular.py b/dianna/dashboard/pages/Tabular.py index ef67309f..e78915bc 100644 --- a/dianna/dashboard/pages/Tabular.py +++ b/dianna/dashboard/pages/Tabular.py @@ -95,7 +95,7 @@ selected_row = grid_response['selected_rows']['Index'][0] selected_data = data.iloc[selected_row, 1:].to_numpy() with st.spinner('Predicting class'): - predictions = predict(modelfile=tabular_model_file, tabular_input=selected_data) + predictions = predict(model_path=tabular_model_file.name, tabular_input=selected_data) with prediction_placeholder: top_indices, top_labels = _get_top_indices_and_labels( From 5b03fe95a9346ed9c380ee81557e27e5c0f451b6 Mon Sep 17 00:00:00 2001 From: Laura Ootes Date: Mon, 12 Aug 2024 15:59:46 +0200 Subject: [PATCH 46/92] model prediction for tabular --- dianna/dashboard/_models_tabular.py | 3 ++- dianna/dashboard/pages/Tabular.py | 7 ++----- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/dianna/dashboard/_models_tabular.py b/dianna/dashboard/_models_tabular.py index 93286489..9930323c 100644 --- a/dianna/dashboard/_models_tabular.py +++ b/dianna/dashboard/_models_tabular.py @@ -6,7 +6,8 @@ @st.cache_data def predict(*, model, tabular_input): - model_runner = SimpleModelRunner(model) + serialized_model = model.SerializeToString() + model_runner = SimpleModelRunner(serialized_model) predictions = model_runner(tabular_input.reshape(1,-1).astype(np.float32)) return predictions diff --git a/dianna/dashboard/pages/Tabular.py b/dianna/dashboard/pages/Tabular.py index e78915bc..0952e748 100644 --- a/dianna/dashboard/pages/Tabular.py +++ b/dianna/dashboard/pages/Tabular.py @@ -57,10 +57,7 @@ data = load_data(tabular_data_file) -#model = load_model(tabular_model_file) -#serialized_model = model.SerializeToString() - -model = SimpleModelRunner(tabular_model_file) +model = load_model(tabular_model_file) if tabular_label_file: labels = load_labels(tabular_label_file) @@ -95,7 +92,7 @@ selected_row = grid_response['selected_rows']['Index'][0] selected_data = data.iloc[selected_row, 1:].to_numpy() with st.spinner('Predicting class'): - predictions = predict(model_path=tabular_model_file.name, tabular_input=selected_data) + predictions = predict(model=model, tabular_input=selected_data) with prediction_placeholder: top_indices, top_labels = _get_top_indices_and_labels( From 5d64e69094373daeb45c878946652f51d0942aba Mon Sep 17 00:00:00 2001 From: Laura Ootes Date: Mon, 12 Aug 2024 16:03:38 +0200 Subject: [PATCH 47/92] use iloc --- dianna/dashboard/pages/Tabular.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dianna/dashboard/pages/Tabular.py b/dianna/dashboard/pages/Tabular.py index 0952e748..7dac7bfc 100644 --- a/dianna/dashboard/pages/Tabular.py +++ b/dianna/dashboard/pages/Tabular.py @@ -89,7 +89,7 @@ ) if grid_response['selected_rows'] is not None: - selected_row = grid_response['selected_rows']['Index'][0] + selected_row = grid_response['selected_rows']['Index'].iloc[0] selected_data = data.iloc[selected_row, 1:].to_numpy() with st.spinner('Predicting class'): predictions = predict(model=model, tabular_input=selected_data) From 5ba7624b8a2233361a4bf2123f6e45e27db281e0 Mon Sep 17 00:00:00 2001 From: Laura Ootes Date: Mon, 12 Aug 2024 16:40:16 +0200 Subject: [PATCH 48/92] import dispatcher --- dianna/dashboard/pages/Tabular.py | 1 + 1 file changed, 1 insertion(+) diff --git a/dianna/dashboard/pages/Tabular.py b/dianna/dashboard/pages/Tabular.py index 7dac7bfc..95080d88 100644 --- a/dianna/dashboard/pages/Tabular.py +++ b/dianna/dashboard/pages/Tabular.py @@ -3,6 +3,7 @@ from _model_utils import load_labels from _model_utils import load_model from _models_tabular import predict +from _models_tabular import explain_tabular_dispatcher from _shared import _get_top_indices_and_labels from _shared import _methods_checkboxes from _shared import add_sidebar_logo From 4d6111891b5c5c3739e0707969524fdc5bd2399f Mon Sep 17 00:00:00 2001 From: Laura Ootes Date: Mon, 12 Aug 2024 16:41:40 +0200 Subject: [PATCH 49/92] update for tabular --- dianna/dashboard/_models_tabular.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/dianna/dashboard/_models_tabular.py b/dianna/dashboard/_models_tabular.py index 9930323c..225090a1 100644 --- a/dianna/dashboard/_models_tabular.py +++ b/dianna/dashboard/_models_tabular.py @@ -34,10 +34,10 @@ def _run_lime_tabular(_model, table, **kwargs): return relevances @st.cache_data -def _run_kernelshap_tabular(model, image, i, **kwargs): +def _run_kernelshap_tabular(_model, table, i, **kwargs): # Kernelshap interface is different. Write model to temporary file. with tempfile.NamedTemporaryFile() as f: - f.write(model) + f.write(_model) f.flush() relevances = explain_tabular( _model, @@ -48,7 +48,7 @@ def _run_kernelshap_tabular(model, image, i, **kwargs): return relevances -explain_text_dispatcher = { +explain_tabular_dispatcher = { 'RISE': _run_rise_tabular, 'LIME': _run_lime_tabular, 'KernelSHAP': _run_kernelshap_tabular From 05a8d37e2c5c69a3f722f92b3de45b82b81716a9 Mon Sep 17 00:00:00 2001 From: Laura Ootes Date: Mon, 12 Aug 2024 16:42:15 +0200 Subject: [PATCH 50/92] fix location of serialization --- dianna/dashboard/_models_tabular.py | 3 +-- dianna/dashboard/pages/Tabular.py | 3 ++- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/dianna/dashboard/_models_tabular.py b/dianna/dashboard/_models_tabular.py index 225090a1..753813ee 100644 --- a/dianna/dashboard/_models_tabular.py +++ b/dianna/dashboard/_models_tabular.py @@ -6,8 +6,7 @@ @st.cache_data def predict(*, model, tabular_input): - serialized_model = model.SerializeToString() - model_runner = SimpleModelRunner(serialized_model) + model_runner = SimpleModelRunner(model) predictions = model_runner(tabular_input.reshape(1,-1).astype(np.float32)) return predictions diff --git a/dianna/dashboard/pages/Tabular.py b/dianna/dashboard/pages/Tabular.py index 95080d88..f8bbc1a4 100644 --- a/dianna/dashboard/pages/Tabular.py +++ b/dianna/dashboard/pages/Tabular.py @@ -59,6 +59,7 @@ data = load_data(tabular_data_file) model = load_model(tabular_model_file) +serialized_model = model.SerializeToString() if tabular_label_file: labels = load_labels(tabular_label_file) @@ -93,7 +94,7 @@ selected_row = grid_response['selected_rows']['Index'].iloc[0] selected_data = data.iloc[selected_row, 1:].to_numpy() with st.spinner('Predicting class'): - predictions = predict(model=model, tabular_input=selected_data) + predictions = predict(model=serialized_model, tabular_input=selected_data) with prediction_placeholder: top_indices, top_labels = _get_top_indices_and_labels( From 6f2ed439edf284ec1c59c567e4229cdb8f9ab8c6 Mon Sep 17 00:00:00 2001 From: Laura Ootes Date: Mon, 12 Aug 2024 16:42:22 +0200 Subject: [PATCH 51/92] add import --- dianna/dashboard/_models_tabular.py | 1 + 1 file changed, 1 insertion(+) diff --git a/dianna/dashboard/_models_tabular.py b/dianna/dashboard/_models_tabular.py index 753813ee..b28153a1 100644 --- a/dianna/dashboard/_models_tabular.py +++ b/dianna/dashboard/_models_tabular.py @@ -1,5 +1,6 @@ import numpy as np import streamlit as st +import tempfile from dianna import explain_tabular from dianna.utils.onnx_runner import SimpleModelRunner From 615f1f503978f5f16d43f66b5e10c5332b0adb27 Mon Sep 17 00:00:00 2001 From: Laura Ootes Date: Mon, 12 Aug 2024 16:42:46 +0200 Subject: [PATCH 52/92] set empty string if not a classifier --- dianna/dashboard/_shared.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dianna/dashboard/_shared.py b/dianna/dashboard/_shared.py index 0dc2ead5..560d5ce6 100644 --- a/dianna/dashboard/_shared.py +++ b/dianna/dashboard/_shared.py @@ -127,7 +127,7 @@ def _get_top_indices_and_labels(*, predictions, labels): st.metric('Predicted class:', top_labels[0]) else: # If not a classifier, only return the predicted value - top_indices = top_labels = 0 + top_indices = top_labels = " " with cols[0]: st.metric('Predicted value:', f"{predictions[0]:.2f}") From 91a04af3374d618c6416e2b18b46557f8b774f43 Mon Sep 17 00:00:00 2001 From: Laura Ootes Date: Tue, 13 Aug 2024 11:12:49 +0200 Subject: [PATCH 53/92] tmp --- dianna/dashboard/pages/Tabular.py | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/dianna/dashboard/pages/Tabular.py b/dianna/dashboard/pages/Tabular.py index f8bbc1a4..1dda3a0a 100644 --- a/dianna/dashboard/pages/Tabular.py +++ b/dianna/dashboard/pages/Tabular.py @@ -106,5 +106,34 @@ st.text("") st.text("") +weight = 0.85 / len(methods) +column_spec = [0.15, *[weight for _ in methods]] + +_, *columns = st.columns(column_spec) +for col, method in zip(columns, methods): + col.markdown(f"

{method}

", unsafe_allow_html=True) + +for index, label in zip(top_indices, top_labels): + index_col, *columns = st.columns(column_spec) + + index_col.markdown(f'##### Class: {label}') + + for col, method in zip(columns, methods): + kwargs = method_params[method].copy() + kwargs['labels'] = [index] + + func = explain_tabular_dispatcher[method] + + + with col: + with st.spinner(f'Running {method}'): + relevances = func(serialized_model, data, **kwargs) + st.stop() + #fig, _ = highlight_text(explanation=relevances[0], show_plot=False) + #st.pyplot(fig) + + # add some white space to separate rows + st.markdown('') + st.stop() \ No newline at end of file From a58d4722c624d259bacbb54b1d5c7b258bd57cf4 Mon Sep 17 00:00:00 2001 From: Laura Ootes Date: Tue, 13 Aug 2024 14:17:48 +0200 Subject: [PATCH 54/92] add required data type --- dianna/dashboard/pages/Tabular.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/dianna/dashboard/pages/Tabular.py b/dianna/dashboard/pages/Tabular.py index 1dda3a0a..50a6db74 100644 --- a/dianna/dashboard/pages/Tabular.py +++ b/dianna/dashboard/pages/Tabular.py @@ -44,7 +44,8 @@ if input_type == 'Use your own data': tabular_data_file = st.sidebar.file_uploader('Select tabular data', type='csv') tabular_model_file = st.sidebar.file_uploader('Select model', - type='onnx') + type='onnx'), + tabular_trainingdata_file = st.sidebar.file_uploader('Select training data', type='npy') tabular_label_file = st.sidebar.file_uploader('Select labels', type='txt') From c352994b6059dfcc9673874385295401f394f94f Mon Sep 17 00:00:00 2001 From: Laura Ootes Date: Tue, 13 Aug 2024 16:31:53 +0200 Subject: [PATCH 55/92] load training data --- dianna/dashboard/pages/Tabular.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/dianna/dashboard/pages/Tabular.py b/dianna/dashboard/pages/Tabular.py index 50a6db74..02e8d4f8 100644 --- a/dianna/dashboard/pages/Tabular.py +++ b/dianna/dashboard/pages/Tabular.py @@ -2,6 +2,7 @@ from _model_utils import load_data from _model_utils import load_labels from _model_utils import load_model +from _model_utils import load_training_data from _models_tabular import predict from _models_tabular import explain_tabular_dispatcher from _shared import _get_top_indices_and_labels @@ -43,17 +44,15 @@ # Option to upload your own data if input_type == 'Use your own data': tabular_data_file = st.sidebar.file_uploader('Select tabular data', type='csv') - tabular_model_file = st.sidebar.file_uploader('Select model', - type='onnx'), - tabular_trainingdata_file = st.sidebar.file_uploader('Select training data', type='npy') - tabular_label_file = st.sidebar.file_uploader('Select labels', - type='txt') + tabular_model_file = st.sidebar.file_uploader('Select model', type='onnx') + tabular_training_data_file = st.sidebar.file_uploader('Select training data', type='npy') + tabular_label_file = st.sidebar.file_uploader('Select labels', type='txt') if input_type is None: st.info('Select which input type to use in the left panel to continue') st.stop() -if not (tabular_data_file and tabular_model_file): +if not (tabular_data_file and tabular_model_file and tabular_training_data_file): st.info('Add your input data in the left panel to continue') st.stop() @@ -62,6 +61,8 @@ model = load_model(tabular_model_file) serialized_model = model.SerializeToString() +training_data = load_training_data(tabular_training_data_file) + if tabular_label_file: labels = load_labels(tabular_label_file) else: From 924b38f96283cbc55837611179b9af7614785dd4 Mon Sep 17 00:00:00 2001 From: Laura Ootes Date: Tue, 13 Aug 2024 16:33:20 +0200 Subject: [PATCH 56/92] delete unused import --- dianna/dashboard/pages/Tabular.py | 1 - 1 file changed, 1 deletion(-) diff --git a/dianna/dashboard/pages/Tabular.py b/dianna/dashboard/pages/Tabular.py index 02e8d4f8..46c6e69f 100644 --- a/dianna/dashboard/pages/Tabular.py +++ b/dianna/dashboard/pages/Tabular.py @@ -13,7 +13,6 @@ from dianna.utils.downloader import download from dianna.utils.onnx_runner import SimpleModelRunner from dianna.visualization import plot_tabular -import pandas as pd from st_aggrid import AgGrid, GridOptionsBuilder, GridUpdateMode add_sidebar_logo() From 880f0dbecc84d673b406c83dcfbf4ec5556d7606 Mon Sep 17 00:00:00 2001 From: Laura Ootes Date: Tue, 13 Aug 2024 16:33:43 +0200 Subject: [PATCH 57/92] convert to flaot for explainer --- dianna/dashboard/pages/Tabular.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dianna/dashboard/pages/Tabular.py b/dianna/dashboard/pages/Tabular.py index 46c6e69f..d31bd5d6 100644 --- a/dianna/dashboard/pages/Tabular.py +++ b/dianna/dashboard/pages/Tabular.py @@ -93,7 +93,7 @@ if grid_response['selected_rows'] is not None: selected_row = grid_response['selected_rows']['Index'].iloc[0] - selected_data = data.iloc[selected_row, 1:].to_numpy() + selected_data = data.iloc[selected_row, 1:].to_numpy(dtype=np.float32) with st.spinner('Predicting class'): predictions = predict(model=serialized_model, tabular_input=selected_data) From 10074643b8f9d45f27af7df395bb3b1174a7e0cb Mon Sep 17 00:00:00 2001 From: Laura Ootes Date: Tue, 13 Aug 2024 16:34:32 +0200 Subject: [PATCH 58/92] provide correct input --- dianna/dashboard/pages/Tabular.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dianna/dashboard/pages/Tabular.py b/dianna/dashboard/pages/Tabular.py index d31bd5d6..732cf45f 100644 --- a/dianna/dashboard/pages/Tabular.py +++ b/dianna/dashboard/pages/Tabular.py @@ -128,7 +128,7 @@ with col: with st.spinner(f'Running {method}'): - relevances = func(serialized_model, data, **kwargs) + relevances = func(serialized_model, selected_data, training_data, **kwargs) st.stop() #fig, _ = highlight_text(explanation=relevances[0], show_plot=False) #st.pyplot(fig) From 7be7045cbdcbc72af8b90895acfd8102d1bb4803 Mon Sep 17 00:00:00 2001 From: Laura Ootes Date: Tue, 13 Aug 2024 16:35:00 +0200 Subject: [PATCH 59/92] add break point --- dianna/dashboard/pages/Tabular.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/dianna/dashboard/pages/Tabular.py b/dianna/dashboard/pages/Tabular.py index 732cf45f..947e460f 100644 --- a/dianna/dashboard/pages/Tabular.py +++ b/dianna/dashboard/pages/Tabular.py @@ -1,3 +1,4 @@ +import numpy as np import streamlit as st from _model_utils import load_data from _model_utils import load_labels @@ -103,6 +104,7 @@ else: st.info("Select the input data either by clicking the corresponding row in the table or input the row index above to continue.") + st.stop() st.text("") st.text("") @@ -137,4 +139,4 @@ st.markdown('') -st.stop() \ No newline at end of file +st.stop() From aa16a2114160da6bd02b8cd2a0dd58bb313379a5 Mon Sep 17 00:00:00 2001 From: Laura Ootes Date: Tue, 13 Aug 2024 16:59:20 +0200 Subject: [PATCH 60/92] add mode --- dianna/dashboard/pages/Tabular.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/dianna/dashboard/pages/Tabular.py b/dianna/dashboard/pages/Tabular.py index 947e460f..96a2230c 100644 --- a/dianna/dashboard/pages/Tabular.py +++ b/dianna/dashboard/pages/Tabular.py @@ -46,7 +46,7 @@ tabular_data_file = st.sidebar.file_uploader('Select tabular data', type='csv') tabular_model_file = st.sidebar.file_uploader('Select model', type='onnx') tabular_training_data_file = st.sidebar.file_uploader('Select training data', type='npy') - tabular_label_file = st.sidebar.file_uploader('Select labels', type='txt') + tabular_label_file = st.sidebar.file_uploader('Select labels in case of classification model', type='txt') if input_type is None: st.info('Select which input type to use in the left panel to continue') @@ -65,8 +65,10 @@ if tabular_label_file: labels = load_labels(tabular_label_file) + mode = 'classification' else: labels = None + mode = 'regression' choices = ('RISE', 'LIME', 'KernelSHAP') @@ -124,6 +126,7 @@ for col, method in zip(columns, methods): kwargs = method_params[method].copy() kwargs['labels'] = [index] + kwargs['mode'] = mode func = explain_tabular_dispatcher[method] From 87fd6af2a537e789213b7e5f953ab251d98c8e5e Mon Sep 17 00:00:00 2001 From: Laura Ootes Date: Tue, 13 Aug 2024 16:59:38 +0200 Subject: [PATCH 61/92] add training data argument --- dianna/dashboard/_models_tabular.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/dianna/dashboard/_models_tabular.py b/dianna/dashboard/_models_tabular.py index b28153a1..b6950e9d 100644 --- a/dianna/dashboard/_models_tabular.py +++ b/dianna/dashboard/_models_tabular.py @@ -13,28 +13,30 @@ def predict(*, model, tabular_input): @st.cache_data -def _run_rise_tabular(_model, table, **kwargs): +def _run_rise_tabular(_model, table, training_data, **kwargs): relevances = explain_tabular( _model, table, method='RISE', + training_data=training_data, **kwargs, ) return relevances @st.cache_data -def _run_lime_tabular(_model, table, **kwargs): +def _run_lime_tabular(_model, table, training_data, **kwargs): relevances = explain_tabular( _model, table, method='LIME', + training_data=training_data, **kwargs, ) return relevances @st.cache_data -def _run_kernelshap_tabular(_model, table, i, **kwargs): +def _run_kernelshap_tabular(_model, table, training_data, **kwargs): # Kernelshap interface is different. Write model to temporary file. with tempfile.NamedTemporaryFile() as f: f.write(_model) @@ -43,6 +45,7 @@ def _run_kernelshap_tabular(_model, table, i, **kwargs): _model, table, method='KernelSHAP', + training_data=training_data, **kwargs, ) return relevances From 0de7c9f713bdbe1735f2ba6cf7fe0ffbf8de49a4 Mon Sep 17 00:00:00 2001 From: Laura Ootes Date: Tue, 13 Aug 2024 16:59:56 +0200 Subject: [PATCH 62/92] add load training data function --- dianna/dashboard/_model_utils.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/dianna/dashboard/_model_utils.py b/dianna/dashboard/_model_utils.py index a465c131..40a331c0 100644 --- a/dianna/dashboard/_model_utils.py +++ b/dianna/dashboard/_model_utils.py @@ -40,3 +40,7 @@ def load_labels(file): if labels is None or labels == ['']: raise ValueError(labels) return labels + + +def load_training_data(file): + return np.load(file, allow_pickle=False) From 66e907908f4a9fe511846b30d88114effb04a7cb Mon Sep 17 00:00:00 2001 From: Laura Ootes Date: Tue, 13 Aug 2024 17:10:22 +0200 Subject: [PATCH 63/92] plot relevances results --- dianna/dashboard/pages/Tabular.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/dianna/dashboard/pages/Tabular.py b/dianna/dashboard/pages/Tabular.py index 96a2230c..2e2aeb2c 100644 --- a/dianna/dashboard/pages/Tabular.py +++ b/dianna/dashboard/pages/Tabular.py @@ -134,9 +134,8 @@ with col: with st.spinner(f'Running {method}'): relevances = func(serialized_model, selected_data, training_data, **kwargs) - st.stop() - #fig, _ = highlight_text(explanation=relevances[0], show_plot=False) - #st.pyplot(fig) + fig, _ = plot_tabular(x=relevances, y=data[:1].columns, num_features=10, show_plot=False) + st.pyplot(fig) # add some white space to separate rows st.markdown('') From 004023bfbb84f4562039c17ae175f8d86c398a49 Mon Sep 17 00:00:00 2001 From: Laura Ootes Date: Tue, 13 Aug 2024 17:11:49 +0200 Subject: [PATCH 64/92] only show predicted class if classification model --- dianna/dashboard/pages/Tabular.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dianna/dashboard/pages/Tabular.py b/dianna/dashboard/pages/Tabular.py index 2e2aeb2c..1279d28d 100644 --- a/dianna/dashboard/pages/Tabular.py +++ b/dianna/dashboard/pages/Tabular.py @@ -121,7 +121,7 @@ for index, label in zip(top_indices, top_labels): index_col, *columns = st.columns(column_spec) - index_col.markdown(f'##### Class: {label}') + if mode == 'classification': index_col.markdown(f'##### Class: {label}') for col, method in zip(columns, methods): kwargs = method_params[method].copy() From 924472c94d2d9af51f8999bd81df44b06b43180c Mon Sep 17 00:00:00 2001 From: Laura Ootes Date: Wed, 14 Aug 2024 11:08:03 +0200 Subject: [PATCH 65/92] only add feature names kwarg for LIME --- dianna/dashboard/pages/Tabular.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/dianna/dashboard/pages/Tabular.py b/dianna/dashboard/pages/Tabular.py index 1279d28d..fcb7d463 100644 --- a/dianna/dashboard/pages/Tabular.py +++ b/dianna/dashboard/pages/Tabular.py @@ -127,6 +127,8 @@ kwargs = method_params[method].copy() kwargs['labels'] = [index] kwargs['mode'] = mode + if method == 'LIME': + kwargs['_feature_names']=data[:1].columns.to_list() func = explain_tabular_dispatcher[method] From 3ac39497c02d0b3aa0d93741c06acd1361f27c37 Mon Sep 17 00:00:00 2001 From: Laura Ootes Date: Wed, 14 Aug 2024 11:09:15 +0200 Subject: [PATCH 66/92] add comment for dashboard --- dianna/methods/lime_tabular.py | 1 + 1 file changed, 1 insertion(+) diff --git a/dianna/methods/lime_tabular.py b/dianna/methods/lime_tabular.py index f6b4b1fd..ff7864c7 100644 --- a/dianna/methods/lime_tabular.py +++ b/dianna/methods/lime_tabular.py @@ -63,6 +63,7 @@ def __init__( LimeTabularExplainer, kwargs) # temporary solution for setting num_features and top_labels + # when fixed, also fix in dashboard Tabular.py -> _feature_names self.num_features = len(feature_names) self.explainer = LimeTabularExplainer( From be494763a6a4c31ff33321771470f946c10f1546 Mon Sep 17 00:00:00 2001 From: Laura Ootes Date: Wed, 21 Aug 2024 10:30:20 +0200 Subject: [PATCH 67/92] make sure data is floats --- dianna/dashboard/_model_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dianna/dashboard/_model_utils.py b/dianna/dashboard/_model_utils.py index 40a331c0..8db3ed47 100644 --- a/dianna/dashboard/_model_utils.py +++ b/dianna/dashboard/_model_utils.py @@ -43,4 +43,4 @@ def load_labels(file): def load_training_data(file): - return np.load(file, allow_pickle=False) + return np.float32(np.load(file, allow_pickle=False)) From b108ca460e24577aacbf35463a8268c9eb6c66a7 Mon Sep 17 00:00:00 2001 From: Laura Ootes Date: Wed, 21 Aug 2024 15:08:12 +0200 Subject: [PATCH 68/92] add feature names --- dianna/dashboard/_models_tabular.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/dianna/dashboard/_models_tabular.py b/dianna/dashboard/_models_tabular.py index b6950e9d..87f29c89 100644 --- a/dianna/dashboard/_models_tabular.py +++ b/dianna/dashboard/_models_tabular.py @@ -25,12 +25,13 @@ def _run_rise_tabular(_model, table, training_data, **kwargs): @st.cache_data -def _run_lime_tabular(_model, table, training_data, **kwargs): +def _run_lime_tabular(_model, table, training_data, _feature_names, **kwargs): relevances = explain_tabular( _model, table, method='LIME', training_data=training_data, + feature_names=_feature_names, **kwargs, ) return relevances From b14e9285bc762a65369f6979accaf179982aaa00 Mon Sep 17 00:00:00 2001 From: Laura Ootes Date: Wed, 21 Aug 2024 15:44:36 +0200 Subject: [PATCH 69/92] convert to float32 --- dianna/utils/onnx_runner.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/dianna/utils/onnx_runner.py b/dianna/utils/onnx_runner.py index b6c0ea8b..eb82fee4 100644 --- a/dianna/utils/onnx_runner.py +++ b/dianna/utils/onnx_runner.py @@ -1,4 +1,5 @@ import onnxruntime as ort +import numpy as np class SimpleModelRunner: @@ -29,6 +30,6 @@ def __call__(self, input_data): if self.preprocess_function is not None: input_data = self.preprocess_function(input_data) - onnx_input = {input_name: input_data} + onnx_input = {input_name: input_data.astype(np.float32)} pred_onnx = sess.run([output_name], onnx_input)[0] return pred_onnx From 20f0135a9a7645fc0d6ab8ee1452677166171033 Mon Sep 17 00:00:00 2001 From: Laura Ootes Date: Thu, 22 Aug 2024 10:33:56 +0200 Subject: [PATCH 70/92] fix keyword arguments for tabular --- dianna/dashboard/_shared.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/dianna/dashboard/_shared.py b/dianna/dashboard/_shared.py index 560d5ce6..0d4d2254 100644 --- a/dianna/dashboard/_shared.py +++ b/dianna/dashboard/_shared.py @@ -87,12 +87,16 @@ def _get_params(method: str, key): } elif method == 'KernelSHAP': - return { - 'nsamples': st.number_input('Number of samples', value=1000, key=f'{key}_{method}_nsamp'), - 'background': st.number_input('Background', value=0, key=f'{key}_{method}_background'), - 'n_segments': st.number_input('Number of segments', value=200, key=f'{key}_{method}_nseg'), - 'sigma': st.number_input('σ', value=0, key=f'{key}_{method}_sigma'), - } + if 'Tabular' in key: + return {'training_data_kmeans': st.number_input('Training data kmeans', value=5, key=f'{key}_{method}_training_data_kmeans'), + } + else: + return { + 'nsamples': st.number_input('Number of samples', value=1000, key=f'{key}_{method}_nsamp'), + 'background': st.number_input('Background', value=0, key=f'{key}_{method}_background'), + 'n_segments': st.number_input('Number of segments', value=200, key=f'{key}_{method}_nseg'), + 'sigma': st.number_input('σ', value=0, key=f'{key}_{method}_sigma'), + } elif method == 'LIME': return { From ef5308c958340c912d20e3ffa00eda796fc5d772 Mon Sep 17 00:00:00 2001 From: Laura Ootes Date: Thu, 22 Aug 2024 10:54:43 +0200 Subject: [PATCH 71/92] no kernelshap --- dianna/dashboard/pages/Tabular.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dianna/dashboard/pages/Tabular.py b/dianna/dashboard/pages/Tabular.py index fcb7d463..57779f63 100644 --- a/dianna/dashboard/pages/Tabular.py +++ b/dianna/dashboard/pages/Tabular.py @@ -70,7 +70,7 @@ labels = None mode = 'regression' -choices = ('RISE', 'LIME', 'KernelSHAP') +choices = ('RISE', 'LIME') st.text("") st.text("") From 556aa5f59dbabac30193e9f6aadc193c26deab0d Mon Sep 17 00:00:00 2001 From: Laura Ootes Date: Thu, 22 Aug 2024 10:55:06 +0200 Subject: [PATCH 72/92] kernelshap function for tabular --- dianna/dashboard/_models_tabular.py | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/dianna/dashboard/_models_tabular.py b/dianna/dashboard/_models_tabular.py index 87f29c89..1eb32f41 100644 --- a/dianna/dashboard/_models_tabular.py +++ b/dianna/dashboard/_models_tabular.py @@ -37,19 +37,17 @@ def _run_lime_tabular(_model, table, training_data, _feature_names, **kwargs): return relevances @st.cache_data -def _run_kernelshap_tabular(_model, table, training_data, **kwargs): +def _run_kernelshap_tabular(model, table, training_data, **kwargs): # Kernelshap interface is different. Write model to temporary file. with tempfile.NamedTemporaryFile() as f: - f.write(_model) + f.write(model) f.flush() - relevances = explain_tabular( - _model, - table, - method='KernelSHAP', - training_data=training_data, - **kwargs, - ) - return relevances + relevances = explain_tabular(f.name, + table, + method='KernelSHAP', + training_data=training_data, + **kwargs) + return relevances[0] explain_tabular_dispatcher = { From 125cc16cfd4e6251b8c9546566d2b454df919e1f Mon Sep 17 00:00:00 2001 From: Laura Ootes Date: Thu, 22 Aug 2024 13:44:09 +0200 Subject: [PATCH 73/92] update text --- dianna/dashboard/pages/Tabular.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dianna/dashboard/pages/Tabular.py b/dianna/dashboard/pages/Tabular.py index 57779f63..c4554c32 100644 --- a/dianna/dashboard/pages/Tabular.py +++ b/dianna/dashboard/pages/Tabular.py @@ -105,7 +105,7 @@ predictions=predictions[0], labels=labels) else: - st.info("Select the input data either by clicking the corresponding row in the table or input the row index above to continue.") + st.info("Select the input data by clicking a row in the table.") st.stop() st.text("") From d24075eeb2a04342c7372be4992972cb61f693f8 Mon Sep 17 00:00:00 2001 From: Laura Ootes Date: Thu, 22 Aug 2024 13:48:08 +0200 Subject: [PATCH 74/92] undo weather example specifics --- dianna/dashboard/_model_utils.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/dianna/dashboard/_model_utils.py b/dianna/dashboard/_model_utils.py index 8db3ed47..416061e8 100644 --- a/dianna/dashboard/_model_utils.py +++ b/dianna/dashboard/_model_utils.py @@ -7,11 +7,9 @@ def load_data(file): """Open data from a file and returns it as pandas DataFrame""" df = pd.read_csv(file, parse_dates=True) - # FIXME: only for weather example - data = df.drop(columns=['DATE', 'MONTH'])[:-1] # Add index column - data.insert(0, 'Index', data.index) - return data + df.insert(0, 'Index', df.index) + return df def preprocess_function(image): From b1133fb25a700623607ab393a03c28cf0f50dcac Mon Sep 17 00:00:00 2001 From: Laura Ootes Date: Thu, 22 Aug 2024 14:02:26 +0200 Subject: [PATCH 75/92] add test for tabular --- tests/test_dashboard.py | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/tests/test_dashboard.py b/tests/test_dashboard.py index 08fb02a4..7de47bae 100644 --- a/tests/test_dashboard.py +++ b/tests/test_dashboard.py @@ -245,3 +245,28 @@ def test_timeseries_page(page: Page): "baseButton-secondary").click() page.get_by_label("Select labels").get_by_test_id( "baseButton-secondary").click() + + +def test_tabular_page(page: Page): + """Test performance of tabular page""" + page.goto(f'{BASE_URL}/Tabular') + + page.get_by_text('Running...').wait_for(state='detached') + + expect(page).to_have_title('Tabular') + + expect(page.get_by_text("Select which input type to")).to_be_visible() + + page.locator("label").filter(has_text="Use an example").locator("div").nth(1).click() + + # Test using your own data + page.locator("label").filter( + has_text="Use your own data").locator("div").nth(1).click() + page.get_by_label("Select input data").get_by_test_id( + "baseButton-secondary").click() + page.get_by_label("Select model").get_by_test_id( + "baseButton-secondary").click() + page.get_by_label("Select training data").get_by_test_id( + "baseButton-secondary").click() + page.get_by_label("Select labels in case of classification model").get_by_test_id( + "baseButton-secondary").click() From 38e1d705300c6dd72838af8a2e342ce346958a1c Mon Sep 17 00:00:00 2001 From: Laura Ootes Date: Thu, 22 Aug 2024 15:36:49 +0200 Subject: [PATCH 76/92] fix merge issue --- dianna/dashboard/pages/Time_series.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/dianna/dashboard/pages/Time_series.py b/dianna/dashboard/pages/Time_series.py index 524de1e2..818a40c0 100644 --- a/dianna/dashboard/pages/Time_series.py +++ b/dianna/dashboard/pages/Time_series.py @@ -127,15 +127,12 @@ def preprocess(data): st.text("") st.text("") -with st.spinner('Predicting class'): - predictions = predict(model=serialized_model, ts_data=ts_data_predictor) - with st.container(border=True): prediction_placeholder = st.empty() methods, method_params = _methods_checkboxes(choices=choices, key='TS_cb') with st.spinner('Predicting class'): - predictions = predict(model=serialized_model, ts_data=ts_data_model) + predictions = predict(model=serialized_model, ts_data=ts_data_predictor) with prediction_placeholder: top_indices, top_labels = _get_top_indices_and_labels( From ba1e2960b2e2fbe843ef585b6bc50a1bb4f85af8 Mon Sep 17 00:00:00 2001 From: Laura Ootes Date: Thu, 22 Aug 2024 16:00:58 +0200 Subject: [PATCH 77/92] fix typo --- dianna/dashboard/pages/Text.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dianna/dashboard/pages/Text.py b/dianna/dashboard/pages/Text.py index c24778de..387e37a2 100644 --- a/dianna/dashboard/pages/Text.py +++ b/dianna/dashboard/pages/Text.py @@ -36,7 +36,7 @@ key='Text_load_example') if load_example == 'Movie sentiment': - text_input = 'The movie started out great but the ending was dissappointing' + text_input = 'The movie started out great but the ending was disappointing' text_model_file = download('movie_review_model.onnx', 'model') text_label_file = download('labels_text.txt', 'label') From fb497b9945ac9fe5f7fe76028533952eca2d435c Mon Sep 17 00:00:00 2001 From: Laura Ootes Date: Thu, 22 Aug 2024 16:06:00 +0200 Subject: [PATCH 78/92] fix input conversion --- dianna/utils/onnx_runner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dianna/utils/onnx_runner.py b/dianna/utils/onnx_runner.py index eb82fee4..2dbd9eb2 100644 --- a/dianna/utils/onnx_runner.py +++ b/dianna/utils/onnx_runner.py @@ -30,6 +30,6 @@ def __call__(self, input_data): if self.preprocess_function is not None: input_data = self.preprocess_function(input_data) - onnx_input = {input_name: input_data.astype(np.float32)} + onnx_input = {input_name: input_data} pred_onnx = sess.run([output_name], onnx_input)[0] return pred_onnx From e1b2eb2826a62724318150b6f9acc877a472241a Mon Sep 17 00:00:00 2001 From: Laura Ootes Date: Thu, 22 Aug 2024 16:11:10 +0200 Subject: [PATCH 79/92] undo import --- dianna/utils/onnx_runner.py | 1 - 1 file changed, 1 deletion(-) diff --git a/dianna/utils/onnx_runner.py b/dianna/utils/onnx_runner.py index 2dbd9eb2..b6c0ea8b 100644 --- a/dianna/utils/onnx_runner.py +++ b/dianna/utils/onnx_runner.py @@ -1,5 +1,4 @@ import onnxruntime as ort -import numpy as np class SimpleModelRunner: From c53bd962e1d95790777c2d3c828e20e73435971f Mon Sep 17 00:00:00 2001 From: Laura Ootes Date: Thu, 22 Aug 2024 16:12:42 +0200 Subject: [PATCH 80/92] delete unused imports --- dianna/dashboard/pages/Tabular.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/dianna/dashboard/pages/Tabular.py b/dianna/dashboard/pages/Tabular.py index c4554c32..72557790 100644 --- a/dianna/dashboard/pages/Tabular.py +++ b/dianna/dashboard/pages/Tabular.py @@ -10,9 +10,6 @@ from _shared import _methods_checkboxes from _shared import add_sidebar_logo from _shared import reset_example -from _shared import reset_method -from dianna.utils.downloader import download -from dianna.utils.onnx_runner import SimpleModelRunner from dianna.visualization import plot_tabular from st_aggrid import AgGrid, GridOptionsBuilder, GridUpdateMode From ddef7bf878ea9bd70fa9be8c946f89fac1c9bc75 Mon Sep 17 00:00:00 2001 From: Laura Ootes Date: Thu, 22 Aug 2024 16:32:22 +0200 Subject: [PATCH 81/92] delete print statement --- tests/test_dashboard.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_dashboard.py b/tests/test_dashboard.py index 7de47bae..0db77d67 100644 --- a/tests/test_dashboard.py +++ b/tests/test_dashboard.py @@ -117,7 +117,6 @@ def test_text_page(page: Page): page.get_by_role('img', name='0').nth(2), page.get_by_role('img', name='0').nth(3), ): - print(selector) expect(selector).to_be_visible() # Own data option From d1c5d1cbb52ed46c4f6d24384f212d42cb1efd39 Mon Sep 17 00:00:00 2001 From: Laura Ootes Date: Thu, 22 Aug 2024 16:32:41 +0200 Subject: [PATCH 82/92] check for both positive and negatice --- tests/test_dashboard.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_dashboard.py b/tests/test_dashboard.py index 0db77d67..b16cd009 100644 --- a/tests/test_dashboard.py +++ b/tests/test_dashboard.py @@ -99,7 +99,7 @@ def test_text_page(page: Page): page.locator('label').filter(has_text='RISE').locator('span').click() page.locator('label').filter(has_text='LIME').locator('span').click() - + page.get_by_test_id("stNumberInput-StepUp").click() page.get_by_text('Running...').wait_for(state='detached', timeout=100_000) for selector in ( From 3707fd9ecafdabf1191a7c01cf0839dc973a5241 Mon Sep 17 00:00:00 2001 From: Laura Ootes Date: Thu, 22 Aug 2024 16:36:53 +0200 Subject: [PATCH 83/92] test for two classes --- tests/test_dashboard.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_dashboard.py b/tests/test_dashboard.py index b16cd009..cbbbb14a 100644 --- a/tests/test_dashboard.py +++ b/tests/test_dashboard.py @@ -153,7 +153,7 @@ def test_image_page(page: Page): page.locator('label').filter(has_text='RISE').locator('span').click() page.locator('label').filter(has_text='KernelSHAP').locator('span').click() page.locator('label').filter(has_text='LIME').locator('span').click() - + page.get_by_test_id("stNumberInput-StepUp").click() page.get_by_text('Running...').wait_for(state='detached', timeout=45_000) for selector in ( @@ -201,7 +201,7 @@ def test_timeseries_page(page: Page): page.locator('label').filter(has_text='LIME').locator('span').click() page.locator('label').filter(has_text='RISE').locator('span').click() - + page.get_by_test_id("stNumberInput-StepUp").click() page.get_by_text('Running...').wait_for(state='detached', timeout=100_000) for selector in ( From 37d5a5d5da27319bcc025dccafd2cfb7dcb3a423 Mon Sep 17 00:00:00 2001 From: Laura Ootes Date: Thu, 22 Aug 2024 16:42:22 +0200 Subject: [PATCH 84/92] fix tests --- tests/test_dashboard.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/tests/test_dashboard.py b/tests/test_dashboard.py index cbbbb14a..85653d15 100644 --- a/tests/test_dashboard.py +++ b/tests/test_dashboard.py @@ -261,11 +261,7 @@ def test_tabular_page(page: Page): # Test using your own data page.locator("label").filter( has_text="Use your own data").locator("div").nth(1).click() - page.get_by_label("Select input data").get_by_test_id( - "baseButton-secondary").click() - page.get_by_label("Select model").get_by_test_id( - "baseButton-secondary").click() - page.get_by_label("Select training data").get_by_test_id( - "baseButton-secondary").click() - page.get_by_label("Select labels in case of classification model").get_by_test_id( - "baseButton-secondary").click() + page.get_by_label("Select tabular data").get_by_test_id("baseButton-secondary").click() + page.get_by_label("Select model").get_by_test_id("baseButton-secondary").click() + page.get_by_label("Select training data").get_by_test_id("baseButton-secondary").click() + page.get_by_label("Select labels in case of").get_by_test_id("baseButton-secondary").click() From 3d5c64fa128475868d1639557223ef04cd458d57 Mon Sep 17 00:00:00 2001 From: Laura Ootes Date: Thu, 22 Aug 2024 16:55:09 +0200 Subject: [PATCH 85/92] make ruff happy --- dianna/dashboard/Home.py | 1 - dianna/dashboard/_model_utils.py | 2 +- dianna/dashboard/_shared.py | 12 +++++------- dianna/dashboard/pages/Images.py | 1 + dianna/dashboard/pages/Tabular.py | 5 +++-- dianna/dashboard/pages/Time_series.py | 2 -- tests/test_dashboard.py | 2 +- 7 files changed, 11 insertions(+), 14 deletions(-) diff --git a/dianna/dashboard/Home.py b/dianna/dashboard/Home.py index dc502278..51d5a118 100644 --- a/dianna/dashboard/Home.py +++ b/dianna/dashboard/Home.py @@ -1,6 +1,5 @@ import importlib import streamlit as st -from _shared import add_sidebar_logo from _shared import data_directory from streamlit_option_menu import option_menu diff --git a/dianna/dashboard/_model_utils.py b/dianna/dashboard/_model_utils.py index 416061e8..eb4afed2 100644 --- a/dianna/dashboard/_model_utils.py +++ b/dianna/dashboard/_model_utils.py @@ -1,7 +1,7 @@ from pathlib import Path import numpy as np -import pandas as pd import onnx +import pandas as pd def load_data(file): diff --git a/dianna/dashboard/_shared.py b/dianna/dashboard/_shared.py index 0d4d2254..6ed408ae 100644 --- a/dianna/dashboard/_shared.py +++ b/dianna/dashboard/_shared.py @@ -1,7 +1,5 @@ import base64 import sys -from typing import Any -from typing import Dict from typing import Sequence import numpy as np import streamlit as st @@ -46,13 +44,12 @@ def build_markup_for_logo( def add_sidebar_logo(): - "Upload DIANNA logo to sidebar element" + """Upload DIANNA logo to sidebar element.""" st.sidebar.image(str(data_directory / 'logo.png')) def _methods_checkboxes(*, choices: Sequence, key): - """Get methods from a horizontal row of checkboxes - and the corresponding parameters.""" + """Get methods from a horizontal row of checkboxes and the corresponding parameters.""" n_choices = len(choices) methods = [] method_params = {} @@ -71,7 +68,7 @@ def _methods_checkboxes(*, choices: Sequence, key): # Put the message in the container above message_container.info('Select a method to continue') st.stop() - + return methods, method_params @@ -88,7 +85,8 @@ def _get_params(method: str, key): elif method == 'KernelSHAP': if 'Tabular' in key: - return {'training_data_kmeans': st.number_input('Training data kmeans', value=5, key=f'{key}_{method}_training_data_kmeans'), + return {'training_data_kmeans': st.number_input('Training data kmeans', value=5, + key=f'{key}_{method}_training_data_kmeans'), } else: return { diff --git a/dianna/dashboard/pages/Images.py b/dianna/dashboard/pages/Images.py index edab213d..7c509517 100644 --- a/dianna/dashboard/pages/Images.py +++ b/dianna/dashboard/pages/Images.py @@ -12,6 +12,7 @@ from dianna.utils.downloader import download from dianna.visualization import plot_image + add_sidebar_logo() st.title('Image explanation') diff --git a/dianna/dashboard/pages/Tabular.py b/dianna/dashboard/pages/Tabular.py index 72557790..b8e6c1ad 100644 --- a/dianna/dashboard/pages/Tabular.py +++ b/dianna/dashboard/pages/Tabular.py @@ -13,6 +13,7 @@ from dianna.visualization import plot_tabular from st_aggrid import AgGrid, GridOptionsBuilder, GridUpdateMode + add_sidebar_logo() st.title('Tabular data explanation') @@ -118,7 +119,8 @@ for index, label in zip(top_indices, top_labels): index_col, *columns = st.columns(column_spec) - if mode == 'classification': index_col.markdown(f'##### Class: {label}') + if mode == 'classification': + index_col.markdown(f'##### Class: {label}') for col, method in zip(columns, methods): kwargs = method_params[method].copy() @@ -128,7 +130,6 @@ kwargs['_feature_names']=data[:1].columns.to_list() func = explain_tabular_dispatcher[method] - with col: with st.spinner(f'Running {method}'): diff --git a/dianna/dashboard/pages/Time_series.py b/dianna/dashboard/pages/Time_series.py index 818a40c0..526f01ab 100644 --- a/dianna/dashboard/pages/Time_series.py +++ b/dianna/dashboard/pages/Time_series.py @@ -16,8 +16,6 @@ from dianna.visualization import plot_timeseries - - st.title('Time series explanation') add_sidebar_logo() diff --git a/tests/test_dashboard.py b/tests/test_dashboard.py index 85653d15..1675e882 100644 --- a/tests/test_dashboard.py +++ b/tests/test_dashboard.py @@ -247,7 +247,7 @@ def test_timeseries_page(page: Page): def test_tabular_page(page: Page): - """Test performance of tabular page""" + """Test performance of tabular page.""" page.goto(f'{BASE_URL}/Tabular') page.get_by_text('Running...').wait_for(state='detached') From 123e2ab3ff3dc5163a25f16310777b083db8f89a Mon Sep 17 00:00:00 2001 From: Laura Ootes Date: Fri, 23 Aug 2024 08:48:37 +0200 Subject: [PATCH 86/92] increase timeout --- tests/test_dashboard.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_dashboard.py b/tests/test_dashboard.py index 1675e882..29d8ff2c 100644 --- a/tests/test_dashboard.py +++ b/tests/test_dashboard.py @@ -254,7 +254,7 @@ def test_tabular_page(page: Page): expect(page).to_have_title('Tabular') - expect(page.get_by_text("Select which input type to")).to_be_visible() + expect(page.get_by_text("Select which input type to")).to_be_visible(timeout=10000) page.locator("label").filter(has_text="Use an example").locator("div").nth(1).click() From 7f75520bd1e8aad0be02967c38d7d872050d6487 Mon Sep 17 00:00:00 2001 From: Laura Ootes Date: Fri, 23 Aug 2024 08:49:50 +0200 Subject: [PATCH 87/92] fix linting --- dianna/dashboard/_model_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dianna/dashboard/_model_utils.py b/dianna/dashboard/_model_utils.py index eb4afed2..cc8084d0 100644 --- a/dianna/dashboard/_model_utils.py +++ b/dianna/dashboard/_model_utils.py @@ -5,7 +5,7 @@ def load_data(file): - """Open data from a file and returns it as pandas DataFrame""" + """Open data from a file and returns it as pandas DataFrame.""" df = pd.read_csv(file, parse_dates=True) # Add index column df.insert(0, 'Index', df.index) From b2c35907e21488246e9958558185c69088cbbdc1 Mon Sep 17 00:00:00 2001 From: Laura Ootes Date: Fri, 23 Aug 2024 08:50:54 +0200 Subject: [PATCH 88/92] fix linting --- dianna/dashboard/_models_tabular.py | 1 - dianna/dashboard/pages/Images.py | 1 - dianna/dashboard/pages/Tabular.py | 1 - dianna/dashboard/pages/Time_series.py | 1 - 4 files changed, 4 deletions(-) diff --git a/dianna/dashboard/_models_tabular.py b/dianna/dashboard/_models_tabular.py index 1eb32f41..1eb2dfb2 100644 --- a/dianna/dashboard/_models_tabular.py +++ b/dianna/dashboard/_models_tabular.py @@ -4,7 +4,6 @@ from dianna import explain_tabular from dianna.utils.onnx_runner import SimpleModelRunner - @st.cache_data def predict(*, model, tabular_input): model_runner = SimpleModelRunner(model) diff --git a/dianna/dashboard/pages/Images.py b/dianna/dashboard/pages/Images.py index 7c509517..edab213d 100644 --- a/dianna/dashboard/pages/Images.py +++ b/dianna/dashboard/pages/Images.py @@ -12,7 +12,6 @@ from dianna.utils.downloader import download from dianna.visualization import plot_image - add_sidebar_logo() st.title('Image explanation') diff --git a/dianna/dashboard/pages/Tabular.py b/dianna/dashboard/pages/Tabular.py index b8e6c1ad..6ca9c381 100644 --- a/dianna/dashboard/pages/Tabular.py +++ b/dianna/dashboard/pages/Tabular.py @@ -13,7 +13,6 @@ from dianna.visualization import plot_tabular from st_aggrid import AgGrid, GridOptionsBuilder, GridUpdateMode - add_sidebar_logo() st.title('Tabular data explanation') diff --git a/dianna/dashboard/pages/Time_series.py b/dianna/dashboard/pages/Time_series.py index 526f01ab..07c88e15 100644 --- a/dianna/dashboard/pages/Time_series.py +++ b/dianna/dashboard/pages/Time_series.py @@ -15,7 +15,6 @@ from dianna.visualization import plot_image from dianna.visualization import plot_timeseries - st.title('Time series explanation') add_sidebar_logo() From 5d8f19612910753cb9aae437f746f0d2dd15ab88 Mon Sep 17 00:00:00 2001 From: Laura Ootes Date: Fri, 23 Aug 2024 08:54:09 +0200 Subject: [PATCH 89/92] fix sorting --- dianna/dashboard/_models_tabular.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/dianna/dashboard/_models_tabular.py b/dianna/dashboard/_models_tabular.py index 1eb2dfb2..96573326 100644 --- a/dianna/dashboard/_models_tabular.py +++ b/dianna/dashboard/_models_tabular.py @@ -1,9 +1,10 @@ +import tempfile import numpy as np import streamlit as st -import tempfile from dianna import explain_tabular from dianna.utils.onnx_runner import SimpleModelRunner + @st.cache_data def predict(*, model, tabular_input): model_runner = SimpleModelRunner(model) From 3fd787d10f5f2a322650e105e930474701d2ad25 Mon Sep 17 00:00:00 2001 From: Laura Ootes Date: Fri, 23 Aug 2024 08:55:13 +0200 Subject: [PATCH 90/92] fix imports --- dianna/dashboard/pages/Tabular.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/dianna/dashboard/pages/Tabular.py b/dianna/dashboard/pages/Tabular.py index 6ca9c381..f9825648 100644 --- a/dianna/dashboard/pages/Tabular.py +++ b/dianna/dashboard/pages/Tabular.py @@ -4,14 +4,16 @@ from _model_utils import load_labels from _model_utils import load_model from _model_utils import load_training_data -from _models_tabular import predict from _models_tabular import explain_tabular_dispatcher +from _models_tabular import predict from _shared import _get_top_indices_and_labels from _shared import _methods_checkboxes from _shared import add_sidebar_logo from _shared import reset_example +from st_aggrid import AgGrid +from st_aggrid import GridOptionsBuilder +from st_aggrid import GridUpdateMode from dianna.visualization import plot_tabular -from st_aggrid import AgGrid, GridOptionsBuilder, GridUpdateMode add_sidebar_logo() From ce38ea903924bfc53d1f27eab9448949e894ce2d Mon Sep 17 00:00:00 2001 From: Laura Ootes Date: Fri, 23 Aug 2024 09:14:16 +0200 Subject: [PATCH 91/92] increase timeout? --- tests/test_dashboard.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_dashboard.py b/tests/test_dashboard.py index 29d8ff2c..8931ed28 100644 --- a/tests/test_dashboard.py +++ b/tests/test_dashboard.py @@ -254,7 +254,7 @@ def test_tabular_page(page: Page): expect(page).to_have_title('Tabular') - expect(page.get_by_text("Select which input type to")).to_be_visible(timeout=10000) + expect(page.get_by_text("Select which input type to")).to_be_visible(timeout=20_000) page.locator("label").filter(has_text="Use an example").locator("div").nth(1).click() From f38fd097b1ab29a755973589bbed4864c91a3df2 Mon Sep 17 00:00:00 2001 From: Laura Ootes Date: Fri, 23 Aug 2024 09:25:43 +0200 Subject: [PATCH 92/92] increase timeouts --- tests/test_dashboard.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/tests/test_dashboard.py b/tests/test_dashboard.py index 8931ed28..0ce296c0 100644 --- a/tests/test_dashboard.py +++ b/tests/test_dashboard.py @@ -95,7 +95,7 @@ def test_text_page(page: Page): # Movie sentiment example page.locator("label").filter(has_text="Use an example").locator("div").nth(1).click() page.get_by_text("Movie sentiment").click() - expect(page.get_by_text("Select a method to continue")).to_be_visible() + expect(page.get_by_text("Select a method to continue")).to_be_visible(timeout=50_000) page.locator('label').filter(has_text='RISE').locator('span').click() page.locator('label').filter(has_text='LIME').locator('span').click() @@ -148,13 +148,13 @@ def test_image_page(page: Page): page.locator("label").filter(has_text="Use an example").locator("div").nth(1).click() page.get_by_text("Hand-written digit recognition").click() - expect(page.get_by_text('Select a method to continue')).to_be_visible() + expect(page.get_by_text('Select a method to continue')).to_be_visible(timeout=100_000) page.locator('label').filter(has_text='RISE').locator('span').click() page.locator('label').filter(has_text='KernelSHAP').locator('span').click() page.locator('label').filter(has_text='LIME').locator('span').click() page.get_by_test_id("stNumberInput-StepUp").click() - page.get_by_text('Running...').wait_for(state='detached', timeout=45_000) + page.get_by_text('Running...').wait_for(state='detached', timeout=50_000) for selector in ( page.get_by_role('heading', name='RISE').get_by_text('RISE'), @@ -188,7 +188,7 @@ def test_timeseries_page(page: Page): expect(page).to_have_title('Time_series') - expect(page.get_by_text("Select which input type to")).to_be_visible() + expect(page.get_by_text("Select which input type to")).to_be_visible(timeout=100_000) page.locator("label").filter(has_text="Use an example").locator("div").nth(1).click() expect(page.get_by_text("Select an example in the left")).to_be_visible() @@ -197,7 +197,7 @@ def test_timeseries_page(page: Page): # Test weather example page.locator("label").filter(has_text="Weather").locator("div").nth(1).click() - expect(page.get_by_text("Select a method to continue")).to_be_visible() + expect(page.get_by_text("Select a method to continue")).to_be_visible(timeout=100_000) page.locator('label').filter(has_text='LIME').locator('span').click() page.locator('label').filter(has_text='RISE').locator('span').click() @@ -220,7 +220,7 @@ def test_timeseries_page(page: Page): # Test FRB example page.locator("label").filter(has_text="FRB").locator("div").nth(1).click() - expect(page.get_by_text("Select a method to continue")).to_be_visible() + expect(page.get_by_text("Select a method to continue")).to_be_visible(timeout=100_000) page.locator('label').filter(has_text='RISE').locator('span').click() @@ -254,7 +254,7 @@ def test_tabular_page(page: Page): expect(page).to_have_title('Tabular') - expect(page.get_by_text("Select which input type to")).to_be_visible(timeout=20_000) + expect(page.get_by_text("Select which input type to")).to_be_visible(timeout=100_000) page.locator("label").filter(has_text="Use an example").locator("div").nth(1).click()