From 14781234d742cef48b317bfa9d2bf41b1b0ebecd Mon Sep 17 00:00:00 2001 From: Sueda Ciftci Date: Thu, 5 Dec 2024 13:06:06 +0000 Subject: [PATCH] model classes updates --- .gitignore | 8 +++++ .../cleanair/gpjax_models/cli/urbanair_jax | 3 +- .../gpjax_models/data/setup_data.py | 4 +-- .../gpjax_models/models/stgp_mrdgp.py | 9 ++---- .../gpjax_models/models/stgp_svgp.py | 30 ++++++++++++------- .../gpjax_models/parser/model/fit_cli.py | 8 ++--- .../gpjax_models/parser/model/main.py | 2 +- 7 files changed, 39 insertions(+), 25 deletions(-) diff --git a/.gitignore b/.gitignore index f526a48b1..52ed7b7fb 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,14 @@ # OS X .DS_Store +# Data files +*.pkl +*.zip +*.csv + +# build +containers/cleanair/gpjax_models/build/* + # Terraform .terraform terraform/backend_config.tf diff --git a/containers/cleanair/gpjax_models/cli/urbanair_jax b/containers/cleanair/gpjax_models/cli/urbanair_jax index cc619fc2c..82633e37a 100644 --- a/containers/cleanair/gpjax_models/cli/urbanair_jax +++ b/containers/cleanair/gpjax_models/cli/urbanair_jax @@ -1,5 +1,5 @@ #! 
/usr/bin/env python -"""Thenserflow 2 GPflow Models CLI""" +"""GP JAX Models CLI""" from curses import echo import importlib import typer @@ -8,5 +8,6 @@ from gpjax_models.parser import dataset, model app = typer.Typer() app.add_typer(dataset.app, name="dataset") app.add_typer(model.app, name="model") +# app.add_typer(visualization.app, name="visualization") if __name__ == "__main__": app() diff --git a/containers/cleanair/gpjax_models/gpjax_models/data/setup_data.py b/containers/cleanair/gpjax_models/gpjax_models/data/setup_data.py index 1f858cda5..40d24ee0a 100644 --- a/containers/cleanair/gpjax_models/gpjax_models/data/setup_data.py +++ b/containers/cleanair/gpjax_models/gpjax_models/data/setup_data.py @@ -18,6 +18,7 @@ def get_X(df): return np.array( df[["epoch", "lat", "lon", "value_200_total_a_road_primary_length"]] ) + # return np.array(df[["epoch", "lat", "lon"]]) def get_X_trf(df): @@ -275,7 +276,6 @@ def generate_data_laqn(train_data, test_data): train_laqn_df = train_data["laqn"] test_laqn_df = test_data["laqn"] test_hexgrid_df = test_data["hexgrid"] - # Extract X and Y for training data train_laqn_X, train_laqn_Y = process_data(train_laqn_df) @@ -313,7 +313,7 @@ def generate_data_laqn(train_data, test_data): "test": {"laqn": {"df": test_laqn_df}, "hexgrid": {"df": test_hexgrid_df}}, } - with open("raw_data_svgp.pkl", "wb") as file: + with open("raw_data_svgp_only_laqn_best.pkl", "wb") as file: pickle.dump(meta_dict, file) return train_dict, test_dict diff --git a/containers/cleanair/gpjax_models/gpjax_models/models/stgp_mrdgp.py b/containers/cleanair/gpjax_models/gpjax_models/models/stgp_mrdgp.py index fb0cd4475..a2246a43c 100644 --- a/containers/cleanair/gpjax_models/gpjax_models/models/stgp_mrdgp.py +++ b/containers/cleanair/gpjax_models/gpjax_models/models/stgp_mrdgp.py @@ -5,9 +5,6 @@ from scipy.cluster.vq import kmeans2 import objax import jax -from jax.config import config as jax_config - -jax_config.update("jax_enable_x64", True) from tqdm import 
trange from abc import abstractmethod from .predicting.utils import batch_predict @@ -201,14 +198,12 @@ def _reshape_pred(X): # Save predictions with open( - os.path.join(self.results_path, "predictions_mrdgp_7.pkl"), "wb" + os.path.join(self.results_path, "predictions_mrdgp.pkl"), "wb" ) as file: pickle.dump(results, file) # Save inducing points - with open( - os.path.join(self.results_path, "inducing_points_7.pkl"), "wb" - ) as file: + with open(os.path.join(self.results_path, "inducing_points.pkl"), "wb") as file: pickle.dump(inducing_points, file) # Print model and inducing points diff --git a/containers/cleanair/gpjax_models/gpjax_models/models/stgp_svgp.py b/containers/cleanair/gpjax_models/gpjax_models/models/stgp_svgp.py index 8ee7d80b3..ccfc9ccdf 100644 --- a/containers/cleanair/gpjax_models/gpjax_models/models/stgp_svgp.py +++ b/containers/cleanair/gpjax_models/gpjax_models/models/stgp_svgp.py @@ -7,9 +7,7 @@ from jax.example_libraries import stax from jax import random from scipy.cluster.vq import kmeans2 -from jax.config import config as jax_config -jax_config.update("jax_enable_x64", True) import stgp from stgp.models import GP from stgp.kernels import ScaleKernel, RBF @@ -26,6 +24,7 @@ class STGP_SVGP: def __init__( self, + results_path, # Non-default argument moved to the first position M: int = 100, batch_size: int = 100, num_epochs: int = 10, @@ -34,14 +33,19 @@ def __init__( Initialize the JAX-based Air Quality Gaussian Process Model. Args: + results_path (str): Path to the directory for saving results. M (int): Number of inducing variables. batch_size (int): Batch size for training. num_epochs (int): Number of training epochs. """ + self.results_path = results_path self.M = M self.batch_size = batch_size self.num_epochs = num_epochs + # Ensure the results directory exists + os.makedirs(self.results_path, exist_ok=True) + def fit(self, x_train: np.ndarray, y_train: np.ndarray, pred_data) -> list[float]: """ Fit the model to training data. 
@@ -58,7 +62,6 @@ def get_laqn_svgp(X_laqn, Y_laqn): N, D = X_laqn.shape data = Data(X_laqn, Y_laqn) - # data = TransformedData(data, [Log()]) Z = stgp.sparsity.FullSparsity(Z=kmeans2(X_laqn, 200, minit="points")[0]) @@ -67,7 +70,7 @@ def get_laqn_svgp(X_laqn, Y_laqn): kernel=ScaleKernel( RBF( input_dim=D, - lengthscales=[0.1, 0.1, 0.1, 0.1], + lengthscales=[0.1] * D, ), variance=np.nanstd(Y_laqn), ), @@ -97,10 +100,9 @@ def train_laqn(num_epoch, m_laqn): joint_grad = GradDescentTrainer(m_laqn, objax.optimizer.Adam) lc_arr = [] - num_epochs = num_epoch laqn_natgrad.train(1.0, 1) - for i in trange(num_epochs): + for i in trange(num_epoch): lc_i, _ = joint_grad.train(0.01, 1) lc_arr.append(lc_i) @@ -140,12 +142,20 @@ def pred_wrapper(XS): results = predict_laqn_svgp(pred_data, m) print(results["metrics"]) - # Save the loss values to a pickle file - with open("loss_values_svgp_laqn.pkl", "wb") as file: - pickle.dump(loss_values, file) - with open("predictions_svgp_laqn.pkl", "wb") as file: + # Save predictions + with open( + os.path.join(self.results_path, "predictions_svgp_laqn__.pkl"), "wb" + ) as file: pickle.dump(results, file) + # Save inducing points + inducing_points = m.prior[0].sparsity.inducing_locations + with open( + os.path.join(self.results_path, "inducing_points_svgp_laqn__.pkl"), "wb" + ) as file: + pickle.dump(inducing_points, file) + return loss_values + class STGP_SVGP_SAT: def __init__( diff --git a/containers/cleanair/gpjax_models/gpjax_models/parser/model/fit_cli.py b/containers/cleanair/gpjax_models/gpjax_models/parser/model/fit_cli.py index 73d6c12c3..cc15f8eba 100644 --- a/containers/cleanair/gpjax_models/gpjax_models/parser/model/fit_cli.py +++ b/containers/cleanair/gpjax_models/gpjax_models/parser/model/fit_cli.py @@ -84,7 +84,7 @@ def train_svgp_sat( batch_size (int): Batch size for training. num_epochs (int): Number of training epochs. 
""" - model = STGP_SVGP_SAT(M, batch_size, num_epochs, random_seed=42) + model = STGP_SVGP_SAT(M, batch_size, num_epochs) typer.echo("Loading training data!") # Iterate over the directories and subdirectories @@ -138,7 +138,7 @@ def train_svgp_sat( @app.command() def train_svgp_laqn( - root_dir: str, + root_dir: Path, M: int = 500, batch_size: int = 200, num_epochs: int = 1000, @@ -209,8 +209,8 @@ def train_mrdgp( root_dir: Path, M: Optional[int] = 500, batch_size: Optional[int] = 200, - num_epochs: Optional[int] = 100, - pretrain_epochs: Optional[int] = 100, + num_epochs: Optional[int] = 1500, + pretrain_epochs: Optional[int] = 1500, ): """ Train the SVGP_GPF2 model on the given training data. diff --git a/containers/cleanair/gpjax_models/gpjax_models/parser/model/main.py b/containers/cleanair/gpjax_models/gpjax_models/parser/model/main.py index f96655358..59ef97b88 100644 --- a/containers/cleanair/gpjax_models/gpjax_models/parser/model/main.py +++ b/containers/cleanair/gpjax_models/gpjax_models/parser/model/main.py @@ -1,4 +1,4 @@ -"""Parser for GPflow 2 model fitting.""" +"""Parser for JAX model fitting.""" import typer from . import fit_cli