diff --git a/tools/base_model_trainer.py b/tools/base_model_trainer.py index c530237..9a93530 100644 --- a/tools/base_model_trainer.py +++ b/tools/base_model_trainer.py @@ -4,10 +4,10 @@ from feature_importance import FeatureImportanceAnalyzer -from utils import get_html_template, get_html_closing - import pandas as pd +from utils import get_html_closing, get_html_template + logging.basicConfig(level=logging.DEBUG) LOG = logging.getLogger(__name__) @@ -113,12 +113,22 @@ def save_html_report(self): model_name = type(self.best_model).__name__ excluded_params = ['html', 'log_experiment', 'system_log'] - filtered_setup_params = {k: v for k, v in self.setup_params.items() if k not in excluded_params} - setup_params_table = pd.DataFrame(list(filtered_setup_params.items()), columns=['Parameter', 'Value']) - - best_model_params = pd.DataFrame(self.best_model.get_params().items(), columns=['Parameter', 'Value']) - best_model_params.to_csv(os.path.join(self.output_dir, 'best_model.csv'), index=False) - self.results.to_csv(os.path.join(self.output_dir, "comparison_results.csv")) + filtered_setup_params = { + k: v + for k, v in self.setup_params.items() if k not in excluded_params + } + setup_params_table = pd.DataFrame( + list(filtered_setup_params.items()), + columns=['Parameter', 'Value']) + + best_model_params = pd.DataFrame( + self.best_model.get_params().items(), + columns=['Parameter', 'Value']) + best_model_params.to_csv( + os.path.join(self.output_dir, 'best_model.csv'), + index=False) + self.results.to_csv(os.path.join( + self.output_dir, "comparison_results.csv")) plots_html = "" for plot_name, plot_path in self.plots.items(): @@ -126,31 +136,41 @@ def save_html_report(self): plots_html += f"""

{plot_name.capitalize()}

- {plot_name} + {plot_name}
""" - analyzer = FeatureImportanceAnalyzer(data=self.data, target_col=self.target_col, task_type='classification', output_dir=self.output_dir) + analyzer = FeatureImportanceAnalyzer( + data=self.data, + target_col=self.target_col, + task_type='classification', + output_dir=self.output_dir) feature_importance_html = analyzer.run() html_content = f""" {get_html_template()}

PyCaret Model Training Report

-
Setup & Best Model
-
Best Model Plots
-
Feature Importance
+
+ Setup & Best Model
+
+ Best Model Plots
+
+ Feature Importance

Setup Parameters

- {setup_params_table.to_html(index=False, header=False, classes='table')} + {setup_params_table.to_html( + index=False, header=False, classes='table')}
ParameterValue

Best Model: {model_name}

- {best_model_params.to_html(index=False, header=False, classes='table')} + {best_model_params.to_html( + index=False, header=False, classes='table')}
ParameterValue

Comparison Results

@@ -167,7 +187,8 @@ def save_html_report(self): {get_html_closing()} """ - with open(os.path.join(self.output_dir, "comparison_result.html"), "w") as file: + with open(os.path.join( + self.output_dir, "comparison_result.html"), "w") as file: file.write(html_content) def save_dashboard(self): diff --git a/tools/feature_importance.py b/tools/feature_importance.py index ec098df..c0aec03 100644 --- a/tools/feature_importance.py +++ b/tools/feature_importance.py @@ -2,12 +2,11 @@ import logging import os -import pandas as pd - import matplotlib.pyplot as plt -from pycaret.classification import ClassificationExperiment +import pandas as pd +from pycaret.classification import ClassificationExperiment from pycaret.regression import RegressionExperiment logging.basicConfig(level=logging.DEBUG) @@ -15,7 +14,14 @@ class FeatureImportanceAnalyzer: - def __init__(self, task_type, output_dir, data_path=None, data=None, target_col=None): + def __init__( + self, + task_type, + output_dir, + data_path=None, + data=None, + target_col=None): + if data is not None: self.data = data else: @@ -25,7 +31,9 @@ def __init__(self, task_type, output_dir, data_path=None, data=None, target_col self.data = self.data.fillna(self.data.median(numeric_only=True)) self.task_type = task_type self.target = self.data.columns[int(target_col) - 1] - self.exp = ClassificationExperiment() if task_type == 'classification' else RegressionExperiment() + self.exp = ClassificationExperiment() \ + if task_type == 'classification' \ + else RegressionExperiment() self.plots = {} self.output_dir = output_dir @@ -57,10 +65,14 @@ def save_tree_importance(self): 'Importance': importances }).sort_values(by='Importance', ascending=False) plt.figure(figsize=(10, 6)) - plt.barh(feature_importances['Feature'], feature_importances['Importance']) + plt.barh( + feature_importances['Feature'], + feature_importances['Importance']) plt.xlabel('Importance') plt.title('Feature Importance (Random Forest)') - plot_path = os.path.join(self.output_dir, 'tree_importance.png') + plot_path = os.path.join( + self.output_dir, + 'tree_importance.png') plt.savefig(plot_path) plt.close() self.plots['tree_importance'] = plot_path @@ -69,10 +81,13 @@ def save_shap_values(self): model = self.exp.create_model('lightgbm') import shap explainer = shap.Explainer(model) - shap_values = explainer.shap_values(self.data.drop(columns=[self.target])) - shap.summary_plot(shap_values, self.data.drop(columns=[self.target]), show=False) + shap_values = explainer.shap_values( + self.data.drop(columns=[self.target])) + shap.summary_plot(shap_values, self.data.drop( + columns=[self.target]), show=False) plt.title('Shap (LightGBM)') - plot_path = os.path.join(self.output_dir, 'shap_summary.png') + plot_path = os.path.join( + self.output_dir, 'shap_summary.png') plt.savefig(plot_path) plt.close() self.plots['shap_summary'] = plot_path @@ -86,7 +101,7 @@ def generate_feature_importance(self): def encode_image_to_base64(self, img_path): with open(img_path, 'rb') as img_file: return base64.b64encode(img_file.read()).decode('utf-8') - + def generate_html_report(self, coef_html): LOG.info("Generating HTML report") @@ -96,12 +111,15 @@ def generate_html_report(self, coef_html): encoded_image = self.encode_image_to_base64(plot_path) plots_html += f"""
-

Feature importance analysis from a trained Random Forest

-

{'Use gini impurity for calculating feature importance for classification' - 'and Variance Reduction for regression' - if plot_name == 'tree_importance' +

Feature importance analysis from a + trained Random Forest

+

{'Use gini impurity for' + 'calculating feature importance for classification' + 'and Variance Reduction for regression' + if plot_name == 'tree_importance' else 'SHAP Summary from a trained lightgbm'}

- {plot_name} + {plot_name}
""" @@ -110,8 +128,10 @@ def generate_html_report(self, coef_html):

PyCaret Feature Importance Report

-

Coefficients (based on a trained - {'Logistic Regression' if self.task_type == 'classification' else 'Linear Regression'} Model)

+

Coefficients (based on a trained + {'Logistic Regression' + if self.task_type == 'classification' + else 'Linear Regression'} Model)

{coef_html}
{plots_html} @@ -127,14 +147,26 @@ def run(self): LOG.info("Feature importance analysis completed") return html_content + if __name__ == "__main__": import argparse parser = argparse.ArgumentParser(description="Feature Importance Analysis") - parser.add_argument("--data_path", type=str, help="Path to the dataset") - parser.add_argument("--target_col", type=int, help="Index of the target column (1-based)") - parser.add_argument("--task_type", type=str, choices=["classification", "regression"], help="Task type: classification or regression") - parser.add_argument("--output_dir", type=str, help="Directory to save the outputs") + parser.add_argument( + "--data_path", type=str, help="Path to the dataset") + parser.add_argument( + "--target_col", type=int, + help="Index of the target column (1-based)") + parser.add_argument( + "--task_type", type=str, + choices=["classification", "regression"], + help="Task type: classification or regression") + parser.add_argument( + "--output_dir", + type=str, + help="Directory to save the outputs") args = parser.parse_args() - analyzer = FeatureImportanceAnalyzer(args.data_path, args.target_col, args.task_type, args.output_dir) + analyzer = FeatureImportanceAnalyzer( + args.data_path, args.target_col, + args.task_type, args.output_dir) analyzer.run() diff --git a/tools/utils.py b/tools/utils.py index 45acf42..e4ae7f7 100644 --- a/tools/utils.py +++ b/tools/utils.py @@ -84,6 +84,7 @@ def get_html_template():
""" + def get_html_closing(): return """
@@ -96,7 +97,8 @@ def get_html_closing(): }} tablinks = document.getElementsByClassName("tab"); for (i = 0; i < tablinks.length; i++) {{ - tablinks[i].className = tablinks[i].className.replace(" active-tab", ""); + tablinks[i].className = + tablinks[i].className.replace(" active-tab", ""); }} document.getElementById(tabName).style.display = "block"; evt.currentTarget.className += " active-tab";