Skip to content

Commit

Permalink
Integrate explainer plots into report (#25)
Browse files Browse the repository at this point in the history
* init commit

* integrated explainer into classification report

* finished explainer into report

* changed test file
  • Loading branch information
qchiujunhao authored Sep 27, 2024
1 parent 59ec798 commit 9a3b9e1
Show file tree
Hide file tree
Showing 10 changed files with 433 additions and 977 deletions.
83 changes: 81 additions & 2 deletions tools/base_model_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

import pandas as pd

# from sklearn.metrics import auc, precision_recall_curve

from utils import get_html_closing, get_html_template

logging.basicConfig(level=logging.DEBUG)
Expand All @@ -31,7 +33,11 @@ def __init__(
self.target = None
self.best_model = None
self.results = None
self.features_name = None
self.plots = {}
self.expaliner = None
self.plots_explainer_html = None
self.trees = []
for key, value in kwargs.items():
setattr(self, key, value)
self.setup_params = {}
Expand All @@ -43,7 +49,11 @@ def load_data(self):
self.data = pd.read_csv(self.input_file, sep=None, engine='python')
self.data = self.data.apply(pd.to_numeric, errors='coerce')
names = self.data.columns.to_list()
self.target = names[int(self.target_col)-1]
target_index = int(self.target_col)-1
self.target = names[target_index]
self.features_name = [name
for i, name in enumerate(names)
if i != target_index]
if hasattr(self, 'missing_value_strategy'):
if self.missing_value_strategy == 'mean':
self.data = self.data.fillna(
Expand Down Expand Up @@ -110,13 +120,27 @@ def setup_pycaret(self):

def train_model(self):
LOG.info("Training and selecting the best model")
# all_models = None
if hasattr(self, 'models') and self.models is not None:
self.best_model = self.exp.compare_models(
include=self.models)
else:
self.best_model = self.exp.compare_models()
# self.best_model = all_models[0]
self.results = self.exp.pull()

# pr_auc_list = []
# for model in all_models:
# y_pred_prob = self.exp.predict_model(model)
# precision, recall, _ = precision_recall_curve(
# y_pred_prob['actual'], y_pred_prob['Score'])

# pr_auc = auc(recall, precision)
# pr_auc_list.append(pr_auc)

# self.results['PR-AUC'] = pr_auc_list
# self.results.rename(columns={'AUC': 'ROC-AUC'}, inplace=True)

def save_model(self):
LOG.info("Saving the model")
self.exp.save_model(self.best_model, "model")
Expand Down Expand Up @@ -161,6 +185,18 @@ def save_html_report(self):
</div>
"""

tree_plots = ""
for i, tree in enumerate(self.trees):
if tree:
tree_plots += f"""
<div class="plot">
<h3>Tree {i+1}</h3>
<img src="data:image/png;base64,
{tree}"
alt="tree {i+1}">
</div>
"""

analyzer = FeatureImportanceAnalyzer(
data=self.data,
target_col=self.target_col,
Expand All @@ -178,6 +214,9 @@ def save_html_report(self):
Best Model Plots</div>
<div class="tab" onclick="openTab(event, 'feature')">
Feature Importance</div>
<div class="tab" onclick="openTab(event, 'explainer')">
Explainer
</div>
</div>
<div id="summary" class="tab-content">
<h2>Setup Parameters</h2>
Expand All @@ -204,6 +243,10 @@ def save_html_report(self):
<div id="feature" class="tab-content">
{feature_importance_html}
</div>
<div id="explainer" class="tab-content">
{self.plots_explainer_html}
{tree_plots}
</div>
{get_html_closing()}
"""

Expand All @@ -214,11 +257,47 @@ def save_html_report(self):
def save_dashboard(self):
raise NotImplementedError("Subclasses should implement this method")

def generate_plots_explainer(self):
raise NotImplementedError("Subclasses should implement this method")

# not working now
def generate_tree_plots(self):
from sklearn.ensemble import RandomForestClassifier, \
RandomForestRegressor
from xgboost import XGBClassifier, XGBRegressor
from explainerdashboard.explainers import RandomForestExplainer

LOG.info("Generating tree plots")
X_test = self.exp.X_test_transformed.copy()
y_test = self.exp.y_test_transformed

is_rf = isinstance(self.best_model, RandomForestClassifier) or \
isinstance(self.best_model, RandomForestRegressor)

is_xgb = isinstance(self.best_model, XGBClassifier) or \
isinstance(self.best_model, XGBRegressor)

try:
if is_rf:
num_trees = self.best_model.n_estimators
if is_xgb:
num_trees = len(self.best_model.get_booster().get_dump())
explainer = RandomForestExplainer(self.best_model, X_test, y_test)
for i in range(num_trees):
fig = explainer.decisiontree_encoded(tree_idx=i, index=0)
LOG.info(f"Tree {i+1}")
LOG.info(fig)
self.trees.append(fig)
except Exception as e:
LOG.error(f"Error generating tree plots: {e}")

def run(self):
self.load_data()
self.setup_pycaret()
self.train_model()
self.save_model()
self.generate_plots()
self.generate_plots_explainer()
self.generate_tree_plots()
self.save_html_report()
self.save_dashboard()
# self.save_dashboard()
41 changes: 17 additions & 24 deletions tools/feature_importance.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,14 +51,14 @@ def setup_pycaret(self):
LOG.info(self.exp)
self.exp.setup(self.data, **setup_params)

def save_coefficients(self):
model = self.exp.create_model('lr')
coef_df = pd.DataFrame({
'Feature': self.data.columns.drop(self.target),
'Coefficient': model.coef_[0]
})
coef_html = coef_df.to_html(index=False)
return coef_html
# def save_coefficients(self):
# model = self.exp.create_model('lr')
# coef_df = pd.DataFrame({
# 'Feature': self.data.columns.drop(self.target),
# 'Coefficient': model.coef_[0]
# })
# coef_html = coef_df.to_html(index=False)
# return coef_html

def save_tree_importance(self):
model = self.exp.create_model('rf')
Expand Down Expand Up @@ -96,16 +96,15 @@ def save_shap_values(self):
self.plots['shap_summary'] = plot_path

def generate_feature_importance(self):
coef_html = self.save_coefficients()
# coef_html = self.save_coefficients()
self.save_tree_importance()
self.save_shap_values()
return coef_html

def encode_image_to_base64(self, img_path):
with open(img_path, 'rb') as img_file:
return base64.b64encode(img_file.read()).decode('utf-8')

def generate_html_report(self, coef_html):
def generate_html_report(self):
LOG.info("Generating HTML report")

# Read and encode plot images
Expand All @@ -114,13 +113,15 @@ def generate_html_report(self, coef_html):
encoded_image = self.encode_image_to_base64(plot_path)
plots_html += f"""
<div class="plot" id="{plot_name}">
<h2>Feature importance analysis from a
trained Random Forest</h2>
<h2>{'Feature importance analysis from a'
'trained Random Forest'
if plot_name == 'tree_importance'
else 'SHAP Summary from a trained lightgbm'}</h2>
<h3>{'Use gini impurity for'
'calculating feature importance for classification'
'and Variance Reduction for regression'
if plot_name == 'tree_importance'
else 'SHAP Summary from a trained lightgbm'}</h3>
else ''}</h3>
<img src="data:image/png;base64,
{encoded_image}" alt="{plot_name}">
</div>
Expand All @@ -129,14 +130,6 @@ def generate_html_report(self, coef_html):
# Generate HTML content with tabs
html_content = f"""
<h1>PyCaret Feature Importance Report</h1>
<div id="coefficients" class="tabcontent">
<h2>Coefficients (based on a trained
{'Logistic Regression'
if self.task_type == 'classification'
else 'Linear Regression'} Model)</h2>
<div>{coef_html}</div>
</div>
{plots_html}
"""

Expand All @@ -145,8 +138,8 @@ def generate_html_report(self, coef_html):
def run(self):
LOG.info("Running feature importance analysis")
self.setup_pycaret()
coef_html = self.generate_feature_importance()
html_content = self.generate_html_report(coef_html)
self.generate_feature_importance()
html_content = self.generate_html_report()
LOG.info("Feature importance analysis completed")
return html_content

Expand Down
129 changes: 129 additions & 0 deletions tools/pycaret_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

from pycaret.classification import ClassificationExperiment

from utils import add_plot_to_html

LOG = logging.getLogger(__name__)


Expand Down Expand Up @@ -41,3 +43,130 @@ def generate_plots(self):
except Exception as e:
LOG.error(f"Error generating plot {plot_name}: {e}")
continue

def generate_plots_explainer(self):
LOG.info("Generating and saving plots from explainer")

from explainerdashboard import ClassifierExplainer

X_test = self.exp.X_test_transformed.copy()
y_test = self.exp.y_test_transformed

explainer = ClassifierExplainer(self.best_model, X_test, y_test)
self.expaliner = explainer
plots_explainer_html = ""

try:
fig_importance = explainer.plot_importances()
plots_explainer_html += add_plot_to_html(fig_importance)
except Exception as e:
LOG.error(f"Error generating plot importance(mean shap): {e}")

try:
fig_importance_perm = explainer.plot_importances(
kind="permutation")
plots_explainer_html += add_plot_to_html(fig_importance_perm)
except Exception as e:
LOG.error(f"Error generating plot importance(permutation): {e}")

# Uncomment and adjust if needed
# try:
# fig_shap = explainer.plot_shap_summary()
# plots_explainer_html += add_plot_to_html(fig_shap,
# include_plotlyjs=False)
# except Exception as e:
# LOG.error(f"Error generating plot shap: {e}")

# try:
# fig_contributions = explainer.plot_contributions(
# index=0)
# plots_explainer_html += add_plot_to_html(
# fig_contributions, include_plotlyjs=False)
# except Exception as e:
# LOG.error(f"Error generating plot contributions: {e}")

# try:
# for feature in self.features_name:
# fig_dependence = explainer.plot_dependence(col=feature)
# plots_explainer_html += add_plot_to_html(fig_dependence)
# except Exception as e:
# LOG.error(f"Error generating plot dependencies: {e}")

try:
for feature in self.features_name:
fig_pdp = explainer.plot_pdp(feature)
plots_explainer_html += add_plot_to_html(fig_pdp)
except Exception as e:
LOG.error(f"Error generating plot pdp: {e}")

try:
for feature in self.features_name:
fig_interaction = explainer.plot_interaction(
col=feature, interaction_col=feature)
plots_explainer_html += add_plot_to_html(fig_interaction)
except Exception as e:
LOG.error(f"Error generating plot interactions: {e}")

try:
for feature in self.features_name:
fig_interactions_importance = \
explainer.plot_interactions_importance(
col=feature)
plots_explainer_html += add_plot_to_html(
fig_interactions_importance)
except Exception as e:
LOG.error(f"Error generating plot interactions importance: {e}")

# try:
# for feature in self.features_name:
# fig_interactions_detailed = \
# explainer.plot_interactions_detailed(
# col=feature)
# plots_explainer_html += add_plot_to_html(
# fig_interactions_detailed)
# except Exception as e:
# LOG.error(f"Error generating plot interactions detailed: {e}")

try:
fig_precision = explainer.plot_precision()
plots_explainer_html += add_plot_to_html(fig_precision)
except Exception as e:
LOG.error(f"Error generating plot precision: {e}")

try:
fig_cumulative_precision = explainer.plot_cumulative_precision()
plots_explainer_html += add_plot_to_html(fig_cumulative_precision)
except Exception as e:
LOG.error(f"Error generating plot cumulative precision: {e}")

try:
fig_classification = explainer.plot_classification()
plots_explainer_html += add_plot_to_html(fig_classification)
except Exception as e:
LOG.error(f"Error generating plot classification: {e}")

try:
fig_confusion_matrix = explainer.plot_confusion_matrix()
plots_explainer_html += add_plot_to_html(fig_confusion_matrix)
except Exception as e:
LOG.error(f"Error generating plot confusion matrix: {e}")

try:
fig_lift_curve = explainer.plot_lift_curve()
plots_explainer_html += add_plot_to_html(fig_lift_curve)
except Exception as e:
LOG.error(f"Error generating plot lift curve: {e}")

try:
fig_roc_auc = explainer.plot_roc_auc()
plots_explainer_html += add_plot_to_html(fig_roc_auc)
except Exception as e:
LOG.error(f"Error generating plot roc auc: {e}")

try:
fig_pr_auc = explainer.plot_pr_auc()
plots_explainer_html += add_plot_to_html(fig_pr_auc)
except Exception as e:
LOG.error(f"Error generating plot pr auc: {e}")

self.plots_explainer_html = plots_explainer_html
Loading

0 comments on commit 9a3b9e1

Please sign in to comment.