Skip to content

Commit

Permalink
integrate feature importance analysis into comparing (goeckslab#21)
Browse files Browse the repository at this point in the history
* init feature_importance

* integrate result into comparing result

* changed the title of plots

* changed for flake8

* resolved bugs and added tests for best_model.csv

* updated the test file

* clear for lint
  • Loading branch information
qchiujunhao authored Aug 3, 2024
1 parent 919f2a9 commit 59ec798
Show file tree
Hide file tree
Showing 18 changed files with 1,126 additions and 438 deletions.
4 changes: 3 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,6 @@ env
*.pkl
*.ipynb
__pycache__
.DS_Store
.DS_Store
tool_test_output.html
tool_test_output.json
2 changes: 1 addition & 1 deletion Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ RUN apt-get update && \

# Install Python packages
RUN pip install -U pip && \
pip install --no-cache-dir --no-compile pycaret[models]==${VERSION} && \
pip install --no-cache-dir --no-compile pycaret[analysis,models]==${VERSION} && \
pip install --no-cache-dir --no-compile explainerdashboard

# Clean up unnecessary packages
Expand Down
136 changes: 59 additions & 77 deletions tools/base_model_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,31 @@
import logging
import os

from feature_importance import FeatureImportanceAnalyzer

import pandas as pd

from utils import get_html_closing, get_html_template

logging.basicConfig(level=logging.DEBUG)
LOG = logging.getLogger(__name__)


class BaseModelTrainer:

def __init__(self, input_file, target_col, output_dir, **kwargs):
def __init__(
self,
input_file,
target_col,
output_dir,
task_type,
**kwargs
):
self.exp = None # This will be set in the subclass
self.input_file = input_file
self.target_col = target_col
self.output_dir = output_dir
self.task_type = task_type
self.data = None
self.target = None
self.best_model = None
Expand All @@ -29,9 +41,21 @@ def __init__(self, input_file, target_col, output_dir, **kwargs):
def load_data(self):
LOG.info(f"Loading data from {self.input_file}")
self.data = pd.read_csv(self.input_file, sep=None, engine='python')
self.data = self.data.apply(pd.to_numeric, errors='coerce')
names = self.data.columns.to_list()
self.target = names[int(self.target_col)-1]
self.data = self.data.fillna(self.data.median(numeric_only=True))
if hasattr(self, 'missing_value_strategy'):
if self.missing_value_strategy == 'mean':
self.data = self.data.fillna(
self.data.mean(numeric_only=True))
elif self.missing_value_strategy == 'median':
self.data = self.data.fillna(
self.data.median(numeric_only=True))
elif self.missing_value_strategy == 'drop':
self.data = self.data.dropna()
else:
# Default strategy if not specified
self.data = self.data.fillna(self.data.median(numeric_only=True))
self.data.columns = self.data.columns.str.replace('.', '_')

def setup_pycaret(self):
Expand Down Expand Up @@ -116,113 +140,71 @@ def save_html_report(self):
setup_params_table = pd.DataFrame(
list(filtered_setup_params.items()),
columns=['Parameter', 'Value'])
# Save model summary

best_model_params = pd.DataFrame(
self.best_model.get_params().items(),
columns=['Parameter', 'Value'])
best_model_params.to_csv(
os.path.join(self.output_dir, 'best_model.csv'),
index=False)

# Save comparison results
self.results.to_csv(os.path.join(
self.output_dir, "comparison_results.csv"))

# Read and encode plot images
plots_html = ""
for plot_name, plot_path in self.plots.items():
encoded_image = self.encode_image_to_base64(plot_path)
plots_html += f"""
<div class="plot">
<h3>{plot_name.capitalize()}</h3>
<img src="data:image/png;base64,
{encoded_image}" alt="{plot_name}">
<img src="data:image/png;base64,{encoded_image}"
alt="{plot_name}">
</div>
"""

# Generate HTML content
analyzer = FeatureImportanceAnalyzer(
data=self.data,
target_col=self.target_col,
task_type=self.task_type,
output_dir=self.output_dir)
feature_importance_html = analyzer.run()

html_content = f"""
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width,
initial-scale=1.0">
<title>PyCaret Model Training Report</title>
<style>
body {{
font-family: Arial, sans-serif;
margin: 0;
padding: 20px;
background-color: #f4f4f4;
}}
.container {{
max-width: 800px;
margin: auto;
background: white;
padding: 20px;
box-shadow: 0 0 10px rgba(0, 0, 0, 0.1);
}}
h1 {{
text-align: center;
color: #333;
}}
h2 {{
border-bottom: 2px solid #4CAF50;
color: #4CAF50;
padding-bottom: 5px;
}}
table {{
width: 100%;
border-collapse: collapse;
margin: 20px 0;
}}
table, th, td {{
border: 1px solid #ddd;
}}
th, td {{
padding: 8px;
text-align: left;
}}
th {{
background-color: #4CAF50;
color: white;
}}
.plot {{
text-align: center;
margin: 20px 0;
}}
.plot img {{
max-width: 100%;
height: auto;
}}
</style>
</head>
<body>
<div class="container">
<h1>PyCaret Model Training Report</h1>
{get_html_template()}
<h1>PyCaret Model Training Report</h1>
<div class="tabs">
<div class="tab" onclick="openTab(event, 'summary')">
Setup & Best Model</div>
<div class="tab" onclick="openTab(event, 'plots')">
Best Model Plots</div>
<div class="tab" onclick="openTab(event, 'feature')">
Feature Importance</div>
</div>
<div id="summary" class="tab-content">
<h2>Setup Parameters</h2>
<table>
<tr><th>Parameter</th><th>Value</th></tr>
{setup_params_table.to_html(index=False,
header=False, classes='table')}
{setup_params_table.to_html(
index=False, header=False, classes='table')}
</table>
<h2>Best Model: {model_name}</h2>
<table>
<tr><th>Parameter</th><th>Value</th></tr>
{best_model_params.to_html(index=False,
header=False, classes='table')}
{best_model_params.to_html(
index=False, header=False, classes='table')}
</table>
<h2>Comparison Results</h2>
<table>
{self.results.to_html(index=False,
classes='table')}
{self.results.to_html(index=False, classes='table')}
</table>
<h2>Plots</h2>
</div>
<div id="plots" class="tab-content">
<h2>Best Model Plots</h2>
{plots_html}
</div>
</body>
</html>
<div id="feature" class="tab-content">
{feature_importance_html}
</div>
{get_html_closing()}
"""

with open(os.path.join(
Expand Down
175 changes: 175 additions & 0 deletions tools/feature_importance.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
import base64
import logging
import os

import matplotlib.pyplot as plt

import pandas as pd

from pycaret.classification import ClassificationExperiment
from pycaret.regression import RegressionExperiment

logging.basicConfig(level=logging.DEBUG)
LOG = logging.getLogger(__name__)


class FeatureImportanceAnalyzer:
    """Build a feature-importance HTML fragment using PyCaret models.

    The analyzer sets up a PyCaret experiment on the supplied data, trains
    an 'lr', an 'rf' and a 'lightgbm' model, and reports:

    * the 'lr' coefficients as an HTML table,
    * the 'rf' ``feature_importances_`` as a bar chart (PNG),
    * a SHAP summary plot for the 'lightgbm' model (PNG).

    The PNG files are written to ``output_dir`` and embedded as base64 in
    the HTML returned by :meth:`run`.
    """

    # Headings shown above each plot: plot name -> (<h2> text, <h3> text).
    _PLOT_HEADINGS = {
        'tree_importance': (
            'Feature importance analysis from a trained Random Forest',
            'Use gini impurity for calculating feature importance '
            'for classification and Variance Reduction for regression',
        ),
        'shap_summary': (
            'Feature importance analysis from a trained LightGBM',
            'SHAP Summary from a trained lightgbm',
        ),
    }

    def __init__(
            self,
            task_type,
            output_dir,
            data_path=None,
            data=None,
            target_col=None):
        """Load the data and create the matching PyCaret experiment.

        Args:
            task_type: 'classification' selects ClassificationExperiment;
                any other value selects RegressionExperiment.
            output_dir: Directory the plot images are written to.
            data_path: Path of a delimited text file; read only when
                ``data`` is None.
            data: Already loaded DataFrame; used as-is when given.
            target_col: 1-based index of the target column.

        Raises:
            ValueError: If ``target_col`` is missing, or neither ``data``
                nor ``data_path`` is provided.
        """
        if target_col is None:
            raise ValueError(
                "target_col (1-based column index) is required")
        if data is None and data_path is None:
            raise ValueError("either data or data_path must be provided")

        # Always record the index, not only on the file-loading path.
        self.target_col = target_col
        if data is not None:
            self.data = data
            LOG.info("Data loaded from memory")
        else:
            self.data = pd.read_csv(data_path, sep=None, engine='python')
            # regex=False: '.' must be a literal dot on every pandas
            # version, never the "any character" regex.
            self.data.columns = self.data.columns.str.replace(
                '.', '_', regex=False)
            self.data = self.data.fillna(self.data.median(numeric_only=True))
        self.task_type = task_type
        self.target = self.data.columns[int(target_col) - 1]
        self.exp = ClassificationExperiment() \
            if task_type == 'classification' \
            else RegressionExperiment()
        self.plots = {}
        self.output_dir = output_dir

    def setup_pycaret(self):
        """Run ``setup`` on the experiment with a fixed session id."""
        LOG.info("Initializing PyCaret")
        setup_params = {
            'target': self.target,
            'session_id': 123,
            'html': True,
            'log_experiment': False,
            'system_log': False
        }
        LOG.info(self.task_type)
        LOG.info(self.exp)
        self.exp.setup(self.data, **setup_params)

    def save_coefficients(self):
        """Train the 'lr' model and return its coefficients as HTML.

        Returns:
            str: An HTML table with one row per feature.
        """
        model = self.exp.create_model('lr')
        coefficients = model.coef_
        # A 2-D coef_ (one row per class/target) is reduced to its first
        # row; a 1-D coef_ (one value per feature) is used as-is. Indexing
        # a 1-D array with [0] would give a single scalar that pandas
        # broadcasts to every feature row.
        if getattr(coefficients, 'ndim', 1) > 1:
            coefficients = coefficients[0]
        # NOTE(review): assumes the model sees exactly one feature per
        # non-target input column; confirm PyCaret preprocessing does not
        # add or remove columns for the datasets in use.
        coef_df = pd.DataFrame({
            'Feature': self.data.columns.drop(self.target),
            'Coefficient': coefficients
        })
        return coef_df.to_html(index=False)

    def save_tree_importance(self):
        """Train the 'rf' model and save its importances as a bar chart.

        The image is written to ``tree_importance.png`` in ``output_dir``
        and registered in ``self.plots``.
        """
        model = self.exp.create_model('rf')
        feature_importances = pd.DataFrame({
            'Feature': self.data.columns.drop(self.target),
            'Importance': model.feature_importances_
        }).sort_values(by='Importance', ascending=False)
        plt.figure(figsize=(10, 6))
        plt.barh(
            feature_importances['Feature'],
            feature_importances['Importance'])
        plt.xlabel('Importance')
        plt.title('Feature Importance (Random Forest)')
        plot_path = os.path.join(self.output_dir, 'tree_importance.png')
        plt.savefig(plot_path)
        plt.close()
        self.plots['tree_importance'] = plot_path

    def save_shap_values(self):
        """Train the 'lightgbm' model and save a SHAP summary plot.

        The image is written to ``shap_summary.png`` in ``output_dir``
        and registered in ``self.plots``.
        """
        model = self.exp.create_model('lightgbm')
        # Imported here so shap is only needed when this method runs.
        import shap
        # Build the feature matrix once instead of dropping twice.
        features = self.data.drop(columns=[self.target])
        explainer = shap.Explainer(model)
        shap_values = explainer.shap_values(features)
        shap.summary_plot(shap_values, features, show=False)
        plt.title('Shap (LightGBM)')
        plot_path = os.path.join(self.output_dir, 'shap_summary.png')
        plt.savefig(plot_path)
        plt.close()
        self.plots['shap_summary'] = plot_path

    def generate_feature_importance(self):
        """Produce all three artefacts; return the coefficient table."""
        coef_html = self.save_coefficients()
        self.save_tree_importance()
        self.save_shap_values()
        return coef_html

    def encode_image_to_base64(self, img_path):
        """Return the content of ``img_path`` as a base64 text string."""
        with open(img_path, 'rb') as img_file:
            return base64.b64encode(img_file.read()).decode('utf-8')

    def generate_html_report(self, coef_html):
        """Assemble the HTML fragment from the table and saved plots.

        Args:
            coef_html: HTML table returned by :meth:`save_coefficients`.

        Returns:
            str: HTML fragment with the coefficient table followed by one
            ``<div class="plot">`` per entry in ``self.plots``.
        """
        LOG.info("Generating HTML report")

        # Read and encode plot images
        plot_sections = []
        for plot_name, plot_path in self.plots.items():
            encoded_image = self.encode_image_to_base64(plot_path)
            # Each plot gets the heading of the model that produced it;
            # unknown plot names fall back to a title from the name.
            heading, description = self._PLOT_HEADINGS.get(
                plot_name,
                (plot_name.replace('_', ' ').capitalize(), ''))
            # Keep the data URI on one line: no whitespace after the
            # "base64," prefix.
            plot_sections.append(f"""
            <div class="plot" id="{plot_name}">
                <h2>{heading}</h2>
                <h3>{description}</h3>
                <img src="data:image/png;base64,{encoded_image}"
                    alt="{plot_name}">
            </div>
            """)
        plots_html = "".join(plot_sections)

        if self.task_type == 'classification':
            linear_model_name = 'Logistic Regression'
        else:
            linear_model_name = 'Linear Regression'

        # Generate HTML content with tabs
        html_content = f"""
            <h1>PyCaret Feature Importance Report</h1>
            <div id="coefficients" class="tabcontent">
                <h2>Coefficients (based on a trained
                    {linear_model_name} Model)</h2>
                <div>{coef_html}</div>
            </div>
            {plots_html}
        """

        return html_content

    def run(self):
        """Set up PyCaret, compute all artefacts and return the HTML."""
        LOG.info("Running feature importance analysis")
        self.setup_pycaret()
        coef_html = self.generate_feature_importance()
        html_content = self.generate_html_report(coef_html)
        LOG.info("Feature importance analysis completed")
        return html_content


if __name__ == "__main__":
    # Command-line entry point: run the analysis on a dataset file.
    import argparse
    parser = argparse.ArgumentParser(description="Feature Importance Analysis")
    parser.add_argument(
        "--data_path", type=str, required=True,
        help="Path to the dataset")
    parser.add_argument(
        "--target_col", type=int, required=True,
        help="Index of the target column (1-based)")
    parser.add_argument(
        "--task_type", type=str, required=True,
        choices=["classification", "regression"],
        help="Task type: classification or regression")
    parser.add_argument(
        "--output_dir",
        type=str, required=True,
        help="Directory to save the outputs")
    args = parser.parse_args()

    # Pass by keyword: the constructor's positional order is
    # (task_type, output_dir, data_path, data, target_col), so positional
    # (data_path, target_col, task_type, output_dir) binds every value to
    # the wrong parameter.
    analyzer = FeatureImportanceAnalyzer(
        task_type=args.task_type,
        output_dir=args.output_dir,
        data_path=args.data_path,
        target_col=args.target_col)
    analyzer.run()
Loading

0 comments on commit 59ec798

Please sign in to comment.