Replies: 1 comment
I tested it with a case study from the AI4S2S project. Here is the notebook: https://github.com/AI4S2S/cookbook/blob/test_xai_lime/workflow/xai/pred_temperature_LSTM_xai_LIME.ipynb . In this test case, an LSTM model is trained to predict surface temperature over land from sea surface temperature. The data comes from ERA5 and has a spatio-temporal structure (3D data, time x lat x lon).

Originally, the model was trained for sequence-to-sequence prediction. But since LIME can only receive one instance (timestep) at a time, I had to retrain the model for one-to-one prediction. This reflects the difference between predictions on time series data and on tabular data, and it limits how the tabular explainer can be used here.

Since LIME requires the input to be 2D (sequence x channels/variables), the spatial dimensions (lat x lon) need to be flattened. If the model is trained to work with 3D spatio-temporal data, a runner function is needed to make the LIME explainer API happy, but this is easy to solve.

So in general, if we adopt the LIME tabular API in dianna, it will work for regression models, but only for the one-to-one prediction case. (It is possible to go further and unfold the temporal dimension together with the spatial dimensions, in other words a 3D-to-1D unfolding, but this seems to be quite a hassle. I didn't try it. Perhaps we could experiment a bit to see whether the explanations are still acceptable.)
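To make the flattening and the runner function concrete, here is a minimal sketch. It is not the code from the notebook: the grid size, `toy_spatial_model`, and `make_runner` are hypothetical names made up for illustration, and the "model" is just an area mean standing in for the trained network. The only assumption about LIME is that its tabular explainer passes a 2D array (samples x features) to the prediction function and, in regression mode, expects a 1D array back.

```python
import numpy as np

N_LAT, N_LON = 4, 6  # hypothetical grid size, not the ERA5 grid


def toy_spatial_model(fields):
    """Stand-in for a trained model: (batch, lat, lon) -> (batch,)."""
    # Hypothetical "prediction": the area mean of each input field.
    return fields.mean(axis=(1, 2))


def make_runner(model, n_lat, n_lon):
    """Wrap a model that expects gridded input so it accepts flattened rows."""

    def runner(flat_batch):
        flat_batch = np.asarray(flat_batch)
        # Restore the spatial grid: (batch, lat * lon) -> (batch, lat, lon).
        fields = flat_batch.reshape(-1, n_lat, n_lon)
        preds = np.asarray(model(fields))
        # Return one value per sample as a 1D array.
        return preds.reshape(-1)

    return runner


# Three timesteps of gridded data, flattened to one row per timestep.
sst = np.arange(3 * N_LAT * N_LON, dtype=float).reshape(3, N_LAT, N_LON)
sst_flat = sst.reshape(sst.shape[0], -1)  # shape (3, 24)

runner = make_runner(toy_spatial_model, N_LAT, N_LON)
preds = runner(sst_flat)  # shape (3,)
```

The `runner` can then be handed to the explainer in place of the model's own predict function, so the model itself does not need to be retrained on flattened input.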
I spent some time testing LIME for regression models. In LIME (the original implementation), it seems only the tabular explainer supports regression models. They provide a simple regression example in a notebook (https://marcotcr.github.io/lime/tutorials/Using%2Blime%2Bfor%2Bregression.html).

In this example they train a random forest regressor (with `scikit-learn`) to predict one variable from multivariate tabular data. The example is simple and the API is straightforward, so it should be easy to adopt in `dianna` to support tabular data.

(Note: their notebook uses the "Boston housing price" dataset from `scikit-learn`, but this dataset has been deprecated. It is replaced by the "California housing price" dataset. Instead of using `boston = load_boston()`, just call `california = fetch_california_housing()`. The rest stays essentially the same.)