Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding actions for PR #16

Merged
merged 6 commits into from
Jul 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
402 changes: 402 additions & 0 deletions .github/workflows/pr.yaml

Large diffs are not rendered by default.

6 changes: 2 additions & 4 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
.env
env
*.csv
*.tsv
*.log
*.png
*.txt
*.pkl
*.html
*.ipynb
__pycache__
__pycache__
.DS_Store
Binary file added galaxy-master.tar.gz
Binary file not shown.
8 changes: 8 additions & 0 deletions tools/.shed.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
categories:
- Machine Learning
description: Tools for machine learning with pycaret in a simple, powerful and robust way
name: galaxy-pycaret
owner: goeckslab
long_description: Tools for machine learning with pycaret in a simple, powerful and robust way
remote_repository_url: https://github.com/goeckslab/Galaxy-Pycaret
homepage_url: https://github.com/goeckslab/Galaxy-Pycaret
79 changes: 52 additions & 27 deletions tools/base_model_trainer.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
import sys
import pandas as pd
import os
import logging
import base64
import logging
import os

import pandas as pd

logging.basicConfig(level=logging.DEBUG)
LOG = logging.getLogger(__name__)


class BaseModelTrainer:

def __init__(self, input_file, target_col, output_dir, **kwargs):
self.exp = None # This will be set in the subclass
self.input_file = input_file
Expand All @@ -21,7 +23,7 @@ def __init__(self, input_file, target_col, output_dir, **kwargs):
for key, value in kwargs.items():
setattr(self, key, value)
self.setup_params = {}

LOG.info(f"Model kwargs: {self.__dict__}")

def load_data(self):
Expand All @@ -48,26 +50,35 @@ def setup_pycaret(self):
if hasattr(self, 'normalize') and self.normalize is not None:
self.setup_params['normalize'] = self.normalize

if hasattr(self, 'feature_selection') and self.feature_selection is not None:
if hasattr(self, 'feature_selection') and \
self.feature_selection is not None:
self.setup_params['feature_selection'] = self.feature_selection

if hasattr(self, 'cross_validation') and self.cross_validation is not None and self.cross_validation == False:
if hasattr(self, 'cross_validation') and \
self.cross_validation is not None \
and self.cross_validation is False:
self.setup_params['cross_validation'] = self.cross_validation

if hasattr(self, 'cross_validation') and self.cross_validation is not None:
if hasattr(self, 'cross_validation') and \
self.cross_validation is not None:
if hasattr(self, 'cross_validation_folds'):
self.setup_params['fold'] = self.cross_validation_folds

if hasattr(self, 'remove_outliers') and self.remove_outliers is not None:
if hasattr(self, 'remove_outliers') and \
self.remove_outliers is not None:
self.setup_params['remove_outliers'] = self.remove_outliers

if hasattr(self, 'remove_multicollinearity') and self.remove_multicollinearity is not None:
self.setup_params['remove_multicollinearity'] = self.remove_multicollinearity
if hasattr(self, 'remove_multicollinearity') and \
self.remove_multicollinearity is not None:
self.setup_params['remove_multicollinearity'] = \
self.remove_multicollinearity

if hasattr(self, 'polynomial_features') and self.polynomial_features is not None:
if hasattr(self, 'polynomial_features') and \
self.polynomial_features is not None:
self.setup_params['polynomial_features'] = self.polynomial_features

if hasattr(self, 'fix_imbalance') and self.fix_imbalance is not None:
if hasattr(self, 'fix_imbalance') and \
self.fix_imbalance is not None:
self.setup_params['fix_imbalance'] = self.fix_imbalance

LOG.info(self.setup_params)
Expand All @@ -93,17 +104,25 @@ def save_html_report(self):
LOG.info("Saving HTML report")

model_name = type(self.best_model).__name__

excluded_params = ['html', 'log_experiment', 'system_log']
filtered_setup_params = {k: v for k, v in self.setup_params.items() if k not in excluded_params}
setup_params_table = pd.DataFrame(list(filtered_setup_params.items()), columns=['Parameter', 'Value'])

filtered_setup_params = {
k: v
for k, v in self.setup_params.items() if k not in excluded_params
}
setup_params_table = pd.DataFrame(
list(filtered_setup_params.items()),
columns=['Parameter', 'Value'])
# Save model summary
best_model_params = pd.DataFrame(self.best_model.get_params().items(), columns=['Parameter', 'Value'])
best_model_params.to_csv(os.path.join(self.output_dir, 'best_model.csv'), index=False)
best_model_params = pd.DataFrame(
self.best_model.get_params().items(),
columns=['Parameter', 'Value'])
best_model_params.to_csv(
os.path.join(self.output_dir, 'best_model.csv'),
index=False)

# Save comparison results
self.results.to_csv(os.path.join(self.output_dir, "comparison_results.csv"))
self.results.to_csv(os.path.join(
self.output_dir, "comparison_results.csv"))

# Read and encode plot images
plots_html = ""
Expand All @@ -112,7 +131,8 @@ def save_html_report(self):
plots_html += f"""
<div class="plot">
<h3>{plot_name.capitalize()}</h3>
<img src="data:image/png;base64,{encoded_image}" alt="{plot_name}">
<img src="data:image/png;base64,
{encoded_image}" alt="{plot_name}">
</div>
"""

Expand All @@ -122,7 +142,8 @@ def save_html_report(self):
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<meta name="viewport" content="width=device-width,
initial-scale=1.0">
<title>PyCaret Model Training Report</title>
<style>
body {{
Expand Down Expand Up @@ -179,16 +200,19 @@ def save_html_report(self):
<h2>Setup Parameters</h2>
<table>
<tr><th>Parameter</th><th>Value</th></tr>
{setup_params_table.to_html(index=False, header=False, classes='table')}
{setup_params_table.to_html(index=False,
header=False, classes='table')}
</table>
<h2>Best Model: {model_name}</h2>
<table>
<tr><th>Parameter</th><th>Value</th></tr>
{best_model_params.to_html(index=False, header=False, classes='table')}
{best_model_params.to_html(index=False,
header=False, classes='table')}
</table>
<h2>Comparison Results</h2>
<table>
{self.results.to_html(index=False, classes='table')}
{self.results.to_html(index=False,
classes='table')}
</table>
<h2>Plots</h2>
{plots_html}
Expand All @@ -197,12 +221,13 @@ def save_html_report(self):
</html>
"""

with open(os.path.join(self.output_dir, "comparison_result.html"), "w") as file:
with open(os.path.join(
self.output_dir, "comparison_result.html"), "w") as file:
file.write(html_content)

def save_dashboard(self):
raise NotImplementedError("Subclasses should implement this method")

def run(self):
self.load_data()
self.setup_pycaret()
Expand Down
Loading