Skip to content

Commit

Permalink
changed for flask8
Browse files Browse the repository at this point in the history
  • Loading branch information
qchiujunhao committed Aug 1, 2024
1 parent 51beaf4 commit 8993219
Show file tree
Hide file tree
Showing 3 changed files with 95 additions and 40 deletions.
53 changes: 37 additions & 16 deletions tools/base_model_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@

from feature_importance import FeatureImportanceAnalyzer

from utils import get_html_template, get_html_closing

import pandas as pd

from utils import get_html_closing, get_html_template

logging.basicConfig(level=logging.DEBUG)
LOG = logging.getLogger(__name__)

Expand Down Expand Up @@ -113,44 +113,64 @@ def save_html_report(self):

model_name = type(self.best_model).__name__
excluded_params = ['html', 'log_experiment', 'system_log']
filtered_setup_params = {k: v for k, v in self.setup_params.items() if k not in excluded_params}
setup_params_table = pd.DataFrame(list(filtered_setup_params.items()), columns=['Parameter', 'Value'])

best_model_params = pd.DataFrame(self.best_model.get_params().items(), columns=['Parameter', 'Value'])
best_model_params.to_csv(os.path.join(self.output_dir, 'best_model.csv'), index=False)
self.results.to_csv(os.path.join(self.output_dir, "comparison_results.csv"))
filtered_setup_params = {
k: v
for k, v in self.setup_params.items() if k not in excluded_params
}
setup_params_table = pd.DataFrame(
list(filtered_setup_params.items()),
columns=['Parameter', 'Value'])

best_model_params = pd.DataFrame(
self.best_model.get_params().items(),
columns=['Parameter', 'Value'])
best_model_params.to_csv(
os.path.join(self.output_dir, 'best_model.csv'),
index=False)
self.results.to_csv(os.path.join(
self.output_dir, "comparison_results.csv"))

plots_html = ""
for plot_name, plot_path in self.plots.items():
encoded_image = self.encode_image_to_base64(plot_path)
plots_html += f"""
<div class="plot">
<h3>{plot_name.capitalize()}</h3>
<img src="data:image/png;base64,{encoded_image}" alt="{plot_name}">
<img src="data:image/png;base64,{encoded_image}"
alt="{plot_name}">
</div>
"""

analyzer = FeatureImportanceAnalyzer(data=self.data, target_col=self.target_col, task_type='classification', output_dir=self.output_dir)
analyzer = FeatureImportanceAnalyzer(
data=self.data,
target_col=self.target_col,
task_type='classification',
output_dir=self.output_dir)
feature_importance_html = analyzer.run()

html_content = f"""
{get_html_template()}
<h1>PyCaret Model Training Report</h1>
<div class="tabs">
<div class="tab" onclick="openTab(event, 'summary')">Setup & Best Model</div>
<div class="tab" onclick="openTab(event, 'plots')">Best Model Plots</div>
<div class="tab" onclick="openTab(event, 'feature')">Feature Importance</div>
<div class="tab" onclick="openTab(event, 'summary')">
Setup & Best Model</div>
<div class="tab" onclick="openTab(event, 'plots')">
Best Model Plots</div>
<div class="tab" onclick="openTab(event, 'feature')">
Feature Importance</div>
</div>
<div id="summary" class="tab-content">
<h2>Setup Parameters</h2>
<table>
<tr><th>Parameter</th><th>Value</th></tr>
{setup_params_table.to_html(index=False, header=False, classes='table')}
{setup_params_table.to_html(
index=False, header=False, classes='table')}
</table>
<h2>Best Model: {model_name}</h2>
<table>
<tr><th>Parameter</th><th>Value</th></tr>
{best_model_params.to_html(index=False, header=False, classes='table')}
{best_model_params.to_html(
index=False, header=False, classes='table')}
</table>
<h2>Comparison Results</h2>
<table>
Expand All @@ -167,7 +187,8 @@ def save_html_report(self):
{get_html_closing()}
"""

with open(os.path.join(self.output_dir, "comparison_result.html"), "w") as file:
with open(os.path.join(
self.output_dir, "comparison_result.html"), "w") as file:
file.write(html_content)

def save_dashboard(self):
Expand Down
78 changes: 55 additions & 23 deletions tools/feature_importance.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,26 @@
import logging
import os

import pandas as pd

import matplotlib.pyplot as plt

from pycaret.classification import ClassificationExperiment
import pandas as pd

from pycaret.classification import ClassificationExperiment
from pycaret.regression import RegressionExperiment

logging.basicConfig(level=logging.DEBUG)
LOG = logging.getLogger(__name__)


class FeatureImportanceAnalyzer:
def __init__(self, task_type, output_dir, data_path=None, data=None, target_col=None):
def __init__(
self,
task_type,
output_dir,
data_path=None,
data=None,
target_col=None):

if data is not None:
self.data = data
else:
Expand All @@ -25,7 +31,9 @@ def __init__(self, task_type, output_dir, data_path=None, data=None, target_col
self.data = self.data.fillna(self.data.median(numeric_only=True))
self.task_type = task_type
self.target = self.data.columns[int(target_col) - 1]
self.exp = ClassificationExperiment() if task_type == 'classification' else RegressionExperiment()
self.exp = ClassificationExperiment() \
if task_type == 'classification' \
else RegressionExperiment()
self.plots = {}
self.output_dir = output_dir

Expand Down Expand Up @@ -57,10 +65,14 @@ def save_tree_importance(self):
'Importance': importances
}).sort_values(by='Importance', ascending=False)
plt.figure(figsize=(10, 6))
plt.barh(feature_importances['Feature'], feature_importances['Importance'])
plt.barh(
feature_importances['Feature'],
feature_importances['Importance'])
plt.xlabel('Importance')
plt.title('Feature Importance (Random Forest)')
plot_path = os.path.join(self.output_dir, 'tree_importance.png')
plot_path = os.path.join(
self.output_dir,
'tree_importance.png')
plt.savefig(plot_path)
plt.close()
self.plots['tree_importance'] = plot_path
Expand All @@ -69,10 +81,13 @@ def save_shap_values(self):
model = self.exp.create_model('lightgbm')
import shap
explainer = shap.Explainer(model)
shap_values = explainer.shap_values(self.data.drop(columns=[self.target]))
shap.summary_plot(shap_values, self.data.drop(columns=[self.target]), show=False)
shap_values = explainer.shap_values(
self.data.drop(columns=[self.target]))
shap.summary_plot(shap_values, self.data.drop(
columns=[self.target]), show=False)
plt.title('Shap (LightGBM)')
plot_path = os.path.join(self.output_dir, 'shap_summary.png')
plot_path = os.path.join(
self.output_dir, 'shap_summary.png')
plt.savefig(plot_path)
plt.close()
self.plots['shap_summary'] = plot_path
Expand All @@ -86,7 +101,7 @@ def generate_feature_importance(self):
def encode_image_to_base64(self, img_path):
with open(img_path, 'rb') as img_file:
return base64.b64encode(img_file.read()).decode('utf-8')

def generate_html_report(self, coef_html):
LOG.info("Generating HTML report")

Expand All @@ -96,12 +111,15 @@ def generate_html_report(self, coef_html):
encoded_image = self.encode_image_to_base64(plot_path)
plots_html += f"""
<div class="plot" id="{plot_name}">
<h2>Feature importance analysis from a trained Random Forest</h2>
<h3>{'Use gini impurity for calculating feature importance for classification'
'and Variance Reduction for regression'
if plot_name == 'tree_importance'
<h2>Feature importance analysis from a
trained Random Forest</h2>
<h3>{'Use gini impurity for'
'calculating feature importance for classification'
'and Variance Reduction for regression'
if plot_name == 'tree_importance'
else 'SHAP Summary from a trained lightgbm'}</h3>
<img src="data:image/png;base64,{encoded_image}" alt="{plot_name}">
<img src="data:image/png;base64,
{encoded_image}" alt="{plot_name}">
</div>
"""

Expand All @@ -110,8 +128,10 @@ def generate_html_report(self, coef_html):
<h1>PyCaret Feature Importance Report</h1>
<div id="coefficients" class="tabcontent">
<h2>Coefficients (based on a trained
{'Logistic Regression' if self.task_type == 'classification' else 'Linear Regression'} Model)</h2>
<h2>Coefficients (based on a trained
{'Logistic Regression'
if self.task_type == 'classification'
else 'Linear Regression'} Model)</h2>
<div>{coef_html}</div>
</div>
{plots_html}
Expand All @@ -127,14 +147,26 @@ def run(self):
LOG.info("Feature importance analysis completed")
return html_content


if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser(description="Feature Importance Analysis")
parser.add_argument("--data_path", type=str, help="Path to the dataset")
parser.add_argument("--target_col", type=int, help="Index of the target column (1-based)")
parser.add_argument("--task_type", type=str, choices=["classification", "regression"], help="Task type: classification or regression")
parser.add_argument("--output_dir", type=str, help="Directory to save the outputs")
parser.add_argument(
"--data_path", type=str, help="Path to the dataset")
parser.add_argument(
"--target_col", type=int,
help="Index of the target column (1-based)")
parser.add_argument(
"--task_type", type=str,
choices=["classification", "regression"],
help="Task type: classification or regression")
parser.add_argument(
"--output_dir",
type=str,
help="Directory to save the outputs")
args = parser.parse_args()

analyzer = FeatureImportanceAnalyzer(args.data_path, args.target_col, args.task_type, args.output_dir)
analyzer = FeatureImportanceAnalyzer(
args.data_path, args.target_col,
args.task_type, args.output_dir)
analyzer.run()
4 changes: 3 additions & 1 deletion tools/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ def get_html_template():
<div class="container">
"""


def get_html_closing():
return """
</div>
Expand All @@ -96,7 +97,8 @@ def get_html_closing():
}}
tablinks = document.getElementsByClassName("tab");
for (i = 0; i < tablinks.length; i++) {{
tablinks[i].className = tablinks[i].className.replace(" active-tab", "");
tablinks[i].className =
tablinks[i].className.replace(" active-tab", "");
}}
document.getElementById(tabName).style.display = "block";
evt.currentTarget.className += " active-tab";
Expand Down

0 comments on commit 8993219

Please sign in to comment.