Skip to content

Commit

Permalink
Adding actions for PR (goeckslab#16)
Browse files Browse the repository at this point in the history
* updates for pr actions

1. formart for planemo test and lint, and flask8 lint
2. found and fixed bugs during planemo test
3. added test-data files

* Create .shed.yml

* add details to .shed.yml

* importing order

* importing order

* updated pr.yaml and import order

1 for pr.yaml: commend checking file-size for now
2. import order again
  • Loading branch information
qchiujunhao authored Jul 16, 2024
1 parent 5696cf1 commit 2207ffd
Show file tree
Hide file tree
Showing 19 changed files with 2,774 additions and 135 deletions.
402 changes: 402 additions & 0 deletions .github/workflows/pr.yaml

Large diffs are not rendered by default.

6 changes: 2 additions & 4 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
.env
env
*.csv
*.tsv
*.log
*.png
*.txt
*.pkl
*.html
*.ipynb
__pycache__
__pycache__
.DS_Store
Binary file added galaxy-master.tar.gz
Binary file not shown.
8 changes: 8 additions & 0 deletions tools/.shed.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
categories:
- Machine Learning
description: Tools for machine learning with pycaret in a simple, powerful and robust way
name: galaxy-pycaret
owner: goeckslab
long_description: Tools for machine learning with pycaret in a simple, powerful and robust way
remote_repository_url: https://github.com/goeckslab/Galaxy-Pycaret
homepage_url: https://github.com/goeckslab/Galaxy-Pycaret
79 changes: 52 additions & 27 deletions tools/base_model_trainer.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
import sys
import pandas as pd
import os
import logging
import base64
import logging
import os

import pandas as pd

logging.basicConfig(level=logging.DEBUG)
LOG = logging.getLogger(__name__)


class BaseModelTrainer:

def __init__(self, input_file, target_col, output_dir, **kwargs):
self.exp = None # This will be set in the subclass
self.input_file = input_file
Expand All @@ -21,7 +23,7 @@ def __init__(self, input_file, target_col, output_dir, **kwargs):
for key, value in kwargs.items():
setattr(self, key, value)
self.setup_params = {}

LOG.info(f"Model kwargs: {self.__dict__}")

def load_data(self):
Expand All @@ -48,26 +50,35 @@ def setup_pycaret(self):
if hasattr(self, 'normalize') and self.normalize is not None:
self.setup_params['normalize'] = self.normalize

if hasattr(self, 'feature_selection') and self.feature_selection is not None:
if hasattr(self, 'feature_selection') and \
self.feature_selection is not None:
self.setup_params['feature_selection'] = self.feature_selection

if hasattr(self, 'cross_validation') and self.cross_validation is not None and self.cross_validation == False:
if hasattr(self, 'cross_validation') and \
self.cross_validation is not None \
and self.cross_validation is False:
self.setup_params['cross_validation'] = self.cross_validation

if hasattr(self, 'cross_validation') and self.cross_validation is not None:
if hasattr(self, 'cross_validation') and \
self.cross_validation is not None:
if hasattr(self, 'cross_validation_folds'):
self.setup_params['fold'] = self.cross_validation_folds

if hasattr(self, 'remove_outliers') and self.remove_outliers is not None:
if hasattr(self, 'remove_outliers') and \
self.remove_outliers is not None:
self.setup_params['remove_outliers'] = self.remove_outliers

if hasattr(self, 'remove_multicollinearity') and self.remove_multicollinearity is not None:
self.setup_params['remove_multicollinearity'] = self.remove_multicollinearity
if hasattr(self, 'remove_multicollinearity') and \
self.remove_multicollinearity is not None:
self.setup_params['remove_multicollinearity'] = \
self.remove_multicollinearity

if hasattr(self, 'polynomial_features') and self.polynomial_features is not None:
if hasattr(self, 'polynomial_features') and \
self.polynomial_features is not None:
self.setup_params['polynomial_features'] = self.polynomial_features

if hasattr(self, 'fix_imbalance') and self.fix_imbalance is not None:
if hasattr(self, 'fix_imbalance') and \
self.fix_imbalance is not None:
self.setup_params['fix_imbalance'] = self.fix_imbalance

LOG.info(self.setup_params)
Expand All @@ -93,17 +104,25 @@ def save_html_report(self):
LOG.info("Saving HTML report")

model_name = type(self.best_model).__name__

excluded_params = ['html', 'log_experiment', 'system_log']
filtered_setup_params = {k: v for k, v in self.setup_params.items() if k not in excluded_params}
setup_params_table = pd.DataFrame(list(filtered_setup_params.items()), columns=['Parameter', 'Value'])

filtered_setup_params = {
k: v
for k, v in self.setup_params.items() if k not in excluded_params
}
setup_params_table = pd.DataFrame(
list(filtered_setup_params.items()),
columns=['Parameter', 'Value'])
# Save model summary
best_model_params = pd.DataFrame(self.best_model.get_params().items(), columns=['Parameter', 'Value'])
best_model_params.to_csv(os.path.join(self.output_dir, 'best_model.csv'), index=False)
best_model_params = pd.DataFrame(
self.best_model.get_params().items(),
columns=['Parameter', 'Value'])
best_model_params.to_csv(
os.path.join(self.output_dir, 'best_model.csv'),
index=False)

# Save comparison results
self.results.to_csv(os.path.join(self.output_dir, "comparison_results.csv"))
self.results.to_csv(os.path.join(
self.output_dir, "comparison_results.csv"))

# Read and encode plot images
plots_html = ""
Expand All @@ -112,7 +131,8 @@ def save_html_report(self):
plots_html += f"""
<div class="plot">
<h3>{plot_name.capitalize()}</h3>
<img src="data:image/png;base64,{encoded_image}" alt="{plot_name}">
<img src="data:image/png;base64,
{encoded_image}" alt="{plot_name}">
</div>
"""

Expand All @@ -122,7 +142,8 @@ def save_html_report(self):
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<meta name="viewport" content="width=device-width,
initial-scale=1.0">
<title>PyCaret Model Training Report</title>
<style>
body {{
Expand Down Expand Up @@ -179,16 +200,19 @@ def save_html_report(self):
<h2>Setup Parameters</h2>
<table>
<tr><th>Parameter</th><th>Value</th></tr>
{setup_params_table.to_html(index=False, header=False, classes='table')}
{setup_params_table.to_html(index=False,
header=False, classes='table')}
</table>
<h2>Best Model: {model_name}</h2>
<table>
<tr><th>Parameter</th><th>Value</th></tr>
{best_model_params.to_html(index=False, header=False, classes='table')}
{best_model_params.to_html(index=False,
header=False, classes='table')}
</table>
<h2>Comparison Results</h2>
<table>
{self.results.to_html(index=False, classes='table')}
{self.results.to_html(index=False,
classes='table')}
</table>
<h2>Plots</h2>
{plots_html}
Expand All @@ -197,12 +221,13 @@ def save_html_report(self):
</html>
"""

with open(os.path.join(self.output_dir, "comparison_result.html"), "w") as file:
with open(os.path.join(
self.output_dir, "comparison_result.html"), "w") as file:
file.write(html_content)

def save_dashboard(self):
raise NotImplementedError("Subclasses should implement this method")

def run(self):
self.load_data()
self.setup_pycaret()
Expand Down
Loading

0 comments on commit 2207ffd

Please sign in to comment.