Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

integrate feature importance analysis into comparing #21

Merged
merged 7 commits into from
Aug 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,6 @@ env
*.pkl
*.ipynb
__pycache__
.DS_Store
.DS_Store
tool_test_output.html
tool_test_output.json
2 changes: 1 addition & 1 deletion Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ RUN apt-get update && \

# Install Python packages
RUN pip install -U pip && \
pip install --no-cache-dir --no-compile pycaret[models]==${VERSION} && \
pip install --no-cache-dir --no-compile pycaret[analysis,models]==${VERSION} && \
pip install --no-cache-dir --no-compile explainerdashboard

# Clean up unnecessary packages
Expand Down
136 changes: 59 additions & 77 deletions tools/base_model_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,31 @@
import logging
import os

from feature_importance import FeatureImportanceAnalyzer

import pandas as pd

from utils import get_html_closing, get_html_template

logging.basicConfig(level=logging.DEBUG)
LOG = logging.getLogger(__name__)


class BaseModelTrainer:

def __init__(self, input_file, target_col, output_dir, **kwargs):
def __init__(
self,
input_file,
target_col,
output_dir,
task_type,
**kwargs
):
self.exp = None # This will be set in the subclass
self.input_file = input_file
self.target_col = target_col
self.output_dir = output_dir
self.task_type = task_type
self.data = None
self.target = None
self.best_model = None
Expand All @@ -29,9 +41,21 @@ def __init__(self, input_file, target_col, output_dir, **kwargs):
def load_data(self):
LOG.info(f"Loading data from {self.input_file}")
self.data = pd.read_csv(self.input_file, sep=None, engine='python')
self.data = self.data.apply(pd.to_numeric, errors='coerce')
names = self.data.columns.to_list()
self.target = names[int(self.target_col)-1]
self.data = self.data.fillna(self.data.median(numeric_only=True))
if hasattr(self, 'missing_value_strategy'):
if self.missing_value_strategy == 'mean':
self.data = self.data.fillna(
self.data.mean(numeric_only=True))
elif self.missing_value_strategy == 'median':
self.data = self.data.fillna(
self.data.median(numeric_only=True))
elif self.missing_value_strategy == 'drop':
self.data = self.data.dropna()
else:
# Default strategy if not specified
self.data = self.data.fillna(self.data.median(numeric_only=True))
self.data.columns = self.data.columns.str.replace('.', '_')

def setup_pycaret(self):
Expand Down Expand Up @@ -116,113 +140,71 @@ def save_html_report(self):
setup_params_table = pd.DataFrame(
list(filtered_setup_params.items()),
columns=['Parameter', 'Value'])
# Save model summary

best_model_params = pd.DataFrame(
self.best_model.get_params().items(),
columns=['Parameter', 'Value'])
best_model_params.to_csv(
os.path.join(self.output_dir, 'best_model.csv'),
index=False)

# Save comparison results
self.results.to_csv(os.path.join(
self.output_dir, "comparison_results.csv"))

# Read and encode plot images
plots_html = ""
for plot_name, plot_path in self.plots.items():
encoded_image = self.encode_image_to_base64(plot_path)
plots_html += f"""
<div class="plot">
<h3>{plot_name.capitalize()}</h3>
<img src="data:image/png;base64,
{encoded_image}" alt="{plot_name}">
<img src="data:image/png;base64,{encoded_image}"
alt="{plot_name}">
</div>
"""

# Generate HTML content
analyzer = FeatureImportanceAnalyzer(
data=self.data,
target_col=self.target_col,
task_type=self.task_type,
output_dir=self.output_dir)
feature_importance_html = analyzer.run()

html_content = f"""
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width,
initial-scale=1.0">
<title>PyCaret Model Training Report</title>
<style>
body {{
font-family: Arial, sans-serif;
margin: 0;
padding: 20px;
background-color: #f4f4f4;
}}
.container {{
max-width: 800px;
margin: auto;
background: white;
padding: 20px;
box-shadow: 0 0 10px rgba(0, 0, 0, 0.1);
}}
h1 {{
text-align: center;
color: #333;
}}
h2 {{
border-bottom: 2px solid #4CAF50;
color: #4CAF50;
padding-bottom: 5px;
}}
table {{
width: 100%;
border-collapse: collapse;
margin: 20px 0;
}}
table, th, td {{
border: 1px solid #ddd;
}}
th, td {{
padding: 8px;
text-align: left;
}}
th {{
background-color: #4CAF50;
color: white;
}}
.plot {{
text-align: center;
margin: 20px 0;
}}
.plot img {{
max-width: 100%;
height: auto;
}}
</style>
</head>
<body>
<div class="container">
<h1>PyCaret Model Training Report</h1>
{get_html_template()}
<h1>PyCaret Model Training Report</h1>
<div class="tabs">
<div class="tab" onclick="openTab(event, 'summary')">
Setup & Best Model</div>
<div class="tab" onclick="openTab(event, 'plots')">
Best Model Plots</div>
<div class="tab" onclick="openTab(event, 'feature')">
Feature Importance</div>
</div>
<div id="summary" class="tab-content">
<h2>Setup Parameters</h2>
<table>
<tr><th>Parameter</th><th>Value</th></tr>
{setup_params_table.to_html(index=False,
header=False, classes='table')}
{setup_params_table.to_html(
index=False, header=False, classes='table')}
</table>
<h2>Best Model: {model_name}</h2>
<table>
<tr><th>Parameter</th><th>Value</th></tr>
{best_model_params.to_html(index=False,
header=False, classes='table')}
{best_model_params.to_html(
index=False, header=False, classes='table')}
</table>
<h2>Comparison Results</h2>
<table>
{self.results.to_html(index=False,
classes='table')}
{self.results.to_html(index=False, classes='table')}
</table>
<h2>Plots</h2>
</div>
<div id="plots" class="tab-content">
<h2>Best Model Plots</h2>
{plots_html}
</div>
</body>
</html>
<div id="feature" class="tab-content">
{feature_importance_html}
</div>
{get_html_closing()}
"""

with open(os.path.join(
Expand Down
175 changes: 175 additions & 0 deletions tools/feature_importance.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
import base64
import logging
import os

import matplotlib.pyplot as plt

import pandas as pd

from pycaret.classification import ClassificationExperiment
from pycaret.regression import RegressionExperiment

logging.basicConfig(level=logging.DEBUG)
LOG = logging.getLogger(__name__)


class FeatureImportanceAnalyzer:
def __init__(
self,
task_type,
output_dir,
data_path=None,
data=None,
target_col=None):

if data is not None:
self.data = data
LOG.info("Data loaded from memory")
else:
self.target_col = target_col
self.data = pd.read_csv(data_path, sep=None, engine='python')
self.data.columns = self.data.columns.str.replace('.', '_')
self.data = self.data.fillna(self.data.median(numeric_only=True))
self.task_type = task_type
self.target = self.data.columns[int(target_col) - 1]
self.exp = ClassificationExperiment() \
if task_type == 'classification' \
else RegressionExperiment()
self.plots = {}
self.output_dir = output_dir

def setup_pycaret(self):
LOG.info("Initializing PyCaret")
setup_params = {
'target': self.target,
'session_id': 123,
'html': True,
'log_experiment': False,
'system_log': False
}
LOG.info(self.task_type)
LOG.info(self.exp)
self.exp.setup(self.data, **setup_params)

def save_coefficients(self):
model = self.exp.create_model('lr')
coef_df = pd.DataFrame({
'Feature': self.data.columns.drop(self.target),
'Coefficient': model.coef_[0]
})
coef_html = coef_df.to_html(index=False)
return coef_html

def save_tree_importance(self):
model = self.exp.create_model('rf')
importances = model.feature_importances_
feature_importances = pd.DataFrame({
'Feature': self.data.columns.drop(self.target),
'Importance': importances
}).sort_values(by='Importance', ascending=False)
plt.figure(figsize=(10, 6))
plt.barh(
feature_importances['Feature'],
feature_importances['Importance'])
plt.xlabel('Importance')
plt.title('Feature Importance (Random Forest)')
plot_path = os.path.join(
self.output_dir,
'tree_importance.png')
plt.savefig(plot_path)
plt.close()
self.plots['tree_importance'] = plot_path

def save_shap_values(self):
model = self.exp.create_model('lightgbm')
import shap
explainer = shap.Explainer(model)
shap_values = explainer.shap_values(
self.data.drop(columns=[self.target]))
shap.summary_plot(shap_values, self.data.drop(
columns=[self.target]), show=False)
plt.title('Shap (LightGBM)')
plot_path = os.path.join(
self.output_dir, 'shap_summary.png')
plt.savefig(plot_path)
plt.close()
self.plots['shap_summary'] = plot_path

def generate_feature_importance(self):
coef_html = self.save_coefficients()
self.save_tree_importance()
self.save_shap_values()
return coef_html

def encode_image_to_base64(self, img_path):
with open(img_path, 'rb') as img_file:
return base64.b64encode(img_file.read()).decode('utf-8')

def generate_html_report(self, coef_html):
LOG.info("Generating HTML report")

# Read and encode plot images
plots_html = ""
for plot_name, plot_path in self.plots.items():
encoded_image = self.encode_image_to_base64(plot_path)
plots_html += f"""
<div class="plot" id="{plot_name}">
<h2>Feature importance analysis from a
trained Random Forest</h2>
<h3>{'Use gini impurity for'
'calculating feature importance for classification'
'and Variance Reduction for regression'
if plot_name == 'tree_importance'
else 'SHAP Summary from a trained lightgbm'}</h3>
<img src="data:image/png;base64,
{encoded_image}" alt="{plot_name}">
</div>
"""

# Generate HTML content with tabs
html_content = f"""
<h1>PyCaret Feature Importance Report</h1>

<div id="coefficients" class="tabcontent">
<h2>Coefficients (based on a trained
{'Logistic Regression'
if self.task_type == 'classification'
else 'Linear Regression'} Model)</h2>
<div>{coef_html}</div>
</div>
{plots_html}
"""

return html_content

def run(self):
LOG.info("Running feature importance analysis")
self.setup_pycaret()
coef_html = self.generate_feature_importance()
html_content = self.generate_html_report(coef_html)
LOG.info("Feature importance analysis completed")
return html_content


if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser(description="Feature Importance Analysis")
parser.add_argument(
"--data_path", type=str, help="Path to the dataset")
parser.add_argument(
"--target_col", type=int,
help="Index of the target column (1-based)")
parser.add_argument(
"--task_type", type=str,
choices=["classification", "regression"],
help="Task type: classification or regression")
parser.add_argument(
"--output_dir",
type=str,
help="Directory to save the outputs")
args = parser.parse_args()

analyzer = FeatureImportanceAnalyzer(
args.data_path, args.target_col,
args.task_type, args.output_dir)
analyzer.run()
Loading