diff --git a/dianna/dashboard/pages/Images.py b/dianna/dashboard/pages/Images.py index 1f538ead..94c55bef 100644 --- a/dianna/dashboard/pages/Images.py +++ b/dianna/dashboard/pages/Images.py @@ -19,29 +19,40 @@ st.sidebar.header('Input data') -load_example = st.sidebar.checkbox('Load example data', - key='Image_example_check') +load_example_digits = st.sidebar.checkbox('Load hand-written digits example', + key='Image_digits_example_check') image_file = st.sidebar.file_uploader('Select image', type=('png', 'jpg', 'jpeg'), - disabled=load_example) + disabled=load_example_digits) if image_file: st.sidebar.image(image_file) image_model_file = st.sidebar.file_uploader('Select model', type='onnx', - disabled=load_example) + disabled=load_example_digits) image_label_file = st.sidebar.file_uploader('Select labels', type='txt', - disabled=load_example) + disabled=load_example_digits) -if load_example: +if load_example_digits: image_file = (data_directory / 'digit0.jpg') image_model_file = (model_directory / 'mnist_model_tf.onnx') image_label_file = (label_directory / 'labels_mnist.txt') + st.markdown( + """ + This example demonstrates the use of DIANNA on a pretrained binary + [MNIST](https://yann.lecun.com/exdb/mnist/) model using hand-written + digit images. The model predicts for an image of a hand-written 0 or 1, + which of the two it most likely is. This example visualizes the + relevance attributions for each pixel/super-pixel by displaying them on + top of the input image. + """ + ) + if not (image_file and image_model_file and image_label_file): st.info('Add your input data in the left panel to continue') st.stop() diff --git a/dianna/dashboard/pages/Text.py b/dianna/dashboard/pages/Text.py index a426437c..2de02a7c 100644 --- a/dianna/dashboard/pages/Text.py +++ b/dianna/dashboard/pages/Text.py @@ -18,27 +18,36 @@ st.sidebar.header('Input data') -load_example = st.sidebar.checkbox('Load example data', - key='Text_example_check') +load_example_moviesentiment = st.sidebar.checkbox('Load movie sentiment example', + key='Text_example_check_moviesentiment') -text_input = st.sidebar.text_input('Input string', disabled=load_example) +text_input = st.sidebar.text_input('Input string', disabled=load_example_moviesentiment) if text_input: st.sidebar.write(text_input) text_model_file = st.sidebar.file_uploader('Select model', type='onnx', - disabled=load_example) + disabled=load_example_moviesentiment) text_label_file = st.sidebar.file_uploader('Select labels', type='txt', - disabled=load_example) + disabled=load_example_moviesentiment) -if load_example: +if load_example_moviesentiment: text_input = 'The movie started out great but the ending was dissappointing' text_model_file = model_directory / 'movie_review_model.onnx' text_label_file = label_directory / 'labels_text.txt' + st.markdown( + """ + This example demonstrates the use of DIANNA on the [Stanford Sentiment + Treebank dataset](https://nlp.stanford.edu/sentiment/index.html) which + contains one-sentence movie reviews. A pre-trained neural network + classifier is used, which identifies whether a movie review is positive + or negative. + """) + if not (text_input and text_model_file and text_label_file): st.info('Add your input data in the left panel to continue') st.stop() diff --git a/dianna/dashboard/pages/Time_series.py b/dianna/dashboard/pages/Time_series.py index 4e45f576..d6e7d937 100644 --- a/dianna/dashboard/pages/Time_series.py +++ b/dianna/dashboard/pages/Time_series.py @@ -20,26 +20,38 @@ st.sidebar.header('Input data') -load_example = st.sidebar.checkbox('Load example data', key='TS_example_check') +load_example_weather = st.sidebar.checkbox('Load weather example', key='TS_weather_example_check') ts_file = st.sidebar.file_uploader('Select input data', type='npy', - disabled=load_example) + disabled=load_example_weather) ts_model_file = st.sidebar.file_uploader('Select model', type='onnx', - disabled=load_example) + disabled=load_example_weather) ts_label_file = st.sidebar.file_uploader('Select labels', type='txt', - disabled=load_example) + disabled=load_example_weather) -if load_example: +if load_example_weather: ts_file = (data_directory / 'weather_data.npy') ts_model_file = (model_directory / 'season_prediction_model_temp_max_binary.onnx') ts_label_file = (label_directory / 'weather_data_labels.txt') + st.markdown( + """This example demonstrates the use of DIANNA + on a pre-trained binary classification model for season prediction. The + input data is the [weather prediction + dataset](https://zenodo.org/records/5071376). This classification model + uses time (days) as function of mean temperature to predict if the whole + time series is either summer or winter. Using a chosen XAI method the + relevance scores are displayed on top of the timeseries. The days + contributing positively towards the classification decision are + indicated in red and those who contribute negatively in blue. + """) + if not (ts_file and ts_model_file and ts_label_file): st.info('Add your input data in the left panel to continue') st.stop() diff --git a/tests/test_dashboard.py b/tests/test_dashboard.py index 9c5171fe..3fe76ae0 100644 --- a/tests/test_dashboard.py +++ b/tests/test_dashboard.py @@ -98,7 +98,7 @@ def test_text_page(page: Page): expect(selector).to_be_visible(timeout=30_000) page.locator('label').filter( - has_text='Load example data').locator('span').click() + has_text='Load movie sentiment example').locator('span').click() expect(page.get_by_text('Select a method to continue')).to_be_visible() @@ -139,7 +139,7 @@ def test_image_page(page: Page): ).to_be_visible(timeout=100_000) page.locator('label').filter( - has_text='Load example data').locator('span').click() + has_text='Load hand-written digits example').locator('span').click() expect(page.get_by_text('Select a method to continue')).to_be_visible() @@ -180,7 +180,7 @@ def test_timeseries_page(page: Page): ).to_be_visible() page.locator('label').filter( - has_text='Load example data').locator('span').click() + has_text='Load weather example').locator('span').click() expect(page.get_by_text('Select a method to continue')).to_be_visible()