diff --git a/tools/feature_importance.py b/tools/feature_importance.py
index 494c10a..ec098df 100644
--- a/tools/feature_importance.py
+++ b/tools/feature_importance.py
@@ -20,10 +20,10 @@ def __init__(self, task_type, output_dir, data_path=None, data=None, target_col
self.data = data
else:
self.target_col = target_col
- self.task_type = task_type
self.data = pd.read_csv(data_path, sep=None, engine='python')
self.data.columns = self.data.columns.str.replace('.', '_')
self.data = self.data.fillna(self.data.median(numeric_only=True))
+ self.task_type = task_type
self.target = self.data.columns[int(target_col) - 1]
self.exp = ClassificationExperiment() if task_type == 'classification' else RegressionExperiment()
self.plots = {}
@@ -96,21 +96,26 @@ def generate_html_report(self, coef_html):
encoded_image = self.encode_image_to_base64(plot_path)
plots_html += f"""
-
{plot_name.replace('_', ' ').capitalize()}
+
Feature importance analysis from a trained Random Forest
+
{'Use gini impurity for calculating feature importance for classification'
+ 'and Variance Reduction for regression'
+ if plot_name == 'tree_importance'
+ else 'SHAP Summary from a trained lightgbm'}
"""
# Generate HTML content with tabs
html_content = f"""
- PyCaret Feature Importance Report
+ PyCaret Feature Importance Report
-
-
Coefficients (based on a trained Logistic Regression Model)
-
{coef_html}
-
- {plots_html}
- """
+
+
Coefficients (based on a trained
+ {'Logistic Regression' if self.task_type == 'classification' else 'Linear Regression'} Model)
+
{coef_html}
+
+ {plots_html}
+ """
return html_content
diff --git a/tools/pycaret_train.xml b/tools/pycaret_train.xml
index 2a06c11..349a5a5 100644
--- a/tools/pycaret_train.xml
+++ b/tools/pycaret_train.xml
@@ -1,5 +1,5 @@
- Compare different machine learning models on a dataset using PyCaret.
+ Compare different machine learning models on a dataset using PyCaret. Do feature analysis using LR, Random Forest and LightGBM.
pycaret_macros.xml
@@ -174,7 +174,6 @@
This tool uses PyCaret to train and evaluate machine learning models.
- Ensure that the Conda environment specified in the requirements is correctly set up.
\ No newline at end of file