diff --git a/tools/base_model_trainer.py b/tools/base_model_trainer.py
index c530237..9a93530 100644
--- a/tools/base_model_trainer.py
+++ b/tools/base_model_trainer.py
@@ -4,10 +4,10 @@
from feature_importance import FeatureImportanceAnalyzer
-from utils import get_html_template, get_html_closing
-
import pandas as pd
+from utils import get_html_closing, get_html_template
+
logging.basicConfig(level=logging.DEBUG)
LOG = logging.getLogger(__name__)
@@ -113,12 +113,22 @@ def save_html_report(self):
model_name = type(self.best_model).__name__
excluded_params = ['html', 'log_experiment', 'system_log']
- filtered_setup_params = {k: v for k, v in self.setup_params.items() if k not in excluded_params}
- setup_params_table = pd.DataFrame(list(filtered_setup_params.items()), columns=['Parameter', 'Value'])
-
- best_model_params = pd.DataFrame(self.best_model.get_params().items(), columns=['Parameter', 'Value'])
- best_model_params.to_csv(os.path.join(self.output_dir, 'best_model.csv'), index=False)
- self.results.to_csv(os.path.join(self.output_dir, "comparison_results.csv"))
+ filtered_setup_params = {
+ k: v
+ for k, v in self.setup_params.items() if k not in excluded_params
+ }
+ setup_params_table = pd.DataFrame(
+ list(filtered_setup_params.items()),
+ columns=['Parameter', 'Value'])
+
+ best_model_params = pd.DataFrame(
+ self.best_model.get_params().items(),
+ columns=['Parameter', 'Value'])
+ best_model_params.to_csv(
+ os.path.join(self.output_dir, 'best_model.csv'),
+ index=False)
+ self.results.to_csv(os.path.join(
+ self.output_dir, "comparison_results.csv"))
plots_html = ""
for plot_name, plot_path in self.plots.items():
@@ -126,31 +136,41 @@ def save_html_report(self):
plots_html += f"""
{plot_name.capitalize()}
-
+
"""
- analyzer = FeatureImportanceAnalyzer(data=self.data, target_col=self.target_col, task_type='classification', output_dir=self.output_dir)
+ analyzer = FeatureImportanceAnalyzer(
+ data=self.data,
+ target_col=self.target_col,
+ task_type='classification',
+ output_dir=self.output_dir)
feature_importance_html = analyzer.run()
html_content = f"""
{get_html_template()}
PyCaret Model Training Report
-
Setup & Best Model
-
Best Model Plots
-
Feature Importance
+
+ Setup & Best Model
+
+ Best Model Plots
+
+ Feature Importance
Setup Parameters
Parameter | Value |
- {setup_params_table.to_html(index=False, header=False, classes='table')}
+ {setup_params_table.to_html(
+ index=False, header=False, classes='table')}
Best Model: {model_name}
Parameter | Value |
- {best_model_params.to_html(index=False, header=False, classes='table')}
+ {best_model_params.to_html(
+ index=False, header=False, classes='table')}
Comparison Results
@@ -167,7 +187,8 @@ def save_html_report(self):
{get_html_closing()}
"""
- with open(os.path.join(self.output_dir, "comparison_result.html"), "w") as file:
+ with open(os.path.join(
+ self.output_dir, "comparison_result.html"), "w") as file:
file.write(html_content)
def save_dashboard(self):
diff --git a/tools/feature_importance.py b/tools/feature_importance.py
index ec098df..c0aec03 100644
--- a/tools/feature_importance.py
+++ b/tools/feature_importance.py
@@ -2,12 +2,11 @@
import logging
import os
-import pandas as pd
-
import matplotlib.pyplot as plt
-from pycaret.classification import ClassificationExperiment
+import pandas as pd
+from pycaret.classification import ClassificationExperiment
from pycaret.regression import RegressionExperiment
logging.basicConfig(level=logging.DEBUG)
@@ -15,7 +14,14 @@
class FeatureImportanceAnalyzer:
- def __init__(self, task_type, output_dir, data_path=None, data=None, target_col=None):
+ def __init__(
+ self,
+ task_type,
+ output_dir,
+ data_path=None,
+ data=None,
+ target_col=None):
+
if data is not None:
self.data = data
else:
@@ -25,7 +31,9 @@ def __init__(self, task_type, output_dir, data_path=None, data=None, target_col
self.data = self.data.fillna(self.data.median(numeric_only=True))
self.task_type = task_type
self.target = self.data.columns[int(target_col) - 1]
- self.exp = ClassificationExperiment() if task_type == 'classification' else RegressionExperiment()
+ self.exp = ClassificationExperiment() \
+ if task_type == 'classification' \
+ else RegressionExperiment()
self.plots = {}
self.output_dir = output_dir
@@ -57,10 +65,14 @@ def save_tree_importance(self):
'Importance': importances
}).sort_values(by='Importance', ascending=False)
plt.figure(figsize=(10, 6))
- plt.barh(feature_importances['Feature'], feature_importances['Importance'])
+ plt.barh(
+ feature_importances['Feature'],
+ feature_importances['Importance'])
plt.xlabel('Importance')
plt.title('Feature Importance (Random Forest)')
- plot_path = os.path.join(self.output_dir, 'tree_importance.png')
+ plot_path = os.path.join(
+ self.output_dir,
+ 'tree_importance.png')
plt.savefig(plot_path)
plt.close()
self.plots['tree_importance'] = plot_path
@@ -69,10 +81,13 @@ def save_shap_values(self):
model = self.exp.create_model('lightgbm')
import shap
explainer = shap.Explainer(model)
- shap_values = explainer.shap_values(self.data.drop(columns=[self.target]))
- shap.summary_plot(shap_values, self.data.drop(columns=[self.target]), show=False)
+ shap_values = explainer.shap_values(
+ self.data.drop(columns=[self.target]))
+ shap.summary_plot(shap_values, self.data.drop(
+ columns=[self.target]), show=False)
plt.title('Shap (LightGBM)')
- plot_path = os.path.join(self.output_dir, 'shap_summary.png')
+ plot_path = os.path.join(
+ self.output_dir, 'shap_summary.png')
plt.savefig(plot_path)
plt.close()
self.plots['shap_summary'] = plot_path
@@ -86,7 +101,7 @@ def generate_feature_importance(self):
def encode_image_to_base64(self, img_path):
with open(img_path, 'rb') as img_file:
return base64.b64encode(img_file.read()).decode('utf-8')
-
+
def generate_html_report(self, coef_html):
LOG.info("Generating HTML report")
@@ -96,12 +111,15 @@ def generate_html_report(self, coef_html):
encoded_image = self.encode_image_to_base64(plot_path)
plots_html += f"""
-
Feature importance analysis from a trained Random Forest
-
{'Use gini impurity for calculating feature importance for classification'
- 'and Variance Reduction for regression'
- if plot_name == 'tree_importance'
+ Feature importance analysis from a
+ trained Random Forest
+ {'Use gini impurity for '
+ 'calculating feature importance for classification '
+ 'and Variance Reduction for regression'
+ if plot_name == 'tree_importance'
else 'SHAP Summary from a trained lightgbm'}
-
+
"""
@@ -110,8 +128,10 @@ def generate_html_report(self, coef_html):
PyCaret Feature Importance Report
-
Coefficients (based on a trained
- {'Logistic Regression' if self.task_type == 'classification' else 'Linear Regression'} Model)
+
Coefficients (based on a trained
+ {'Logistic Regression'
+ if self.task_type == 'classification'
+ else 'Linear Regression'} Model)
{coef_html}
{plots_html}
@@ -127,14 +147,26 @@ def run(self):
LOG.info("Feature importance analysis completed")
return html_content
+
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser(description="Feature Importance Analysis")
- parser.add_argument("--data_path", type=str, help="Path to the dataset")
- parser.add_argument("--target_col", type=int, help="Index of the target column (1-based)")
- parser.add_argument("--task_type", type=str, choices=["classification", "regression"], help="Task type: classification or regression")
- parser.add_argument("--output_dir", type=str, help="Directory to save the outputs")
+ parser.add_argument(
+ "--data_path", type=str, help="Path to the dataset")
+ parser.add_argument(
+ "--target_col", type=int,
+ help="Index of the target column (1-based)")
+ parser.add_argument(
+ "--task_type", type=str,
+ choices=["classification", "regression"],
+ help="Task type: classification or regression")
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ help="Directory to save the outputs")
args = parser.parse_args()
- analyzer = FeatureImportanceAnalyzer(args.data_path, args.target_col, args.task_type, args.output_dir)
+ analyzer = FeatureImportanceAnalyzer(
+ task_type=args.task_type, output_dir=args.output_dir,
+ data_path=args.data_path, target_col=args.target_col)
analyzer.run()
diff --git a/tools/utils.py b/tools/utils.py
index 45acf42..e4ae7f7 100644
--- a/tools/utils.py
+++ b/tools/utils.py
@@ -84,6 +84,7 @@ def get_html_template():
"""
+
def get_html_closing():
return """
@@ -96,7 +97,8 @@ def get_html_closing():
}}
tablinks = document.getElementsByClassName("tab");
for (i = 0; i < tablinks.length; i++) {{
- tablinks[i].className = tablinks[i].className.replace(" active-tab", "");
+ tablinks[i].className =
+ tablinks[i].className.replace(" active-tab", "");
}}
document.getElementById(tabName).style.display = "block";
evt.currentTarget.className += " active-tab";