Skip to content

Commit

Permalink
model classes updates
Browse files Browse the repository at this point in the history
  • Loading branch information
Sueda Ciftci committed Dec 5, 2024
1 parent 50abe4b commit 1478123
Show file tree
Hide file tree
Showing 7 changed files with 39 additions and 25 deletions.
8 changes: 8 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,6 +1,14 @@
# OS X
.DS_Store

# Data files
.pkl
.zip
.csv

# build
containers/cleanair/gpjax_models/build/*

# Terraform
.terraform
terraform/backend_config.tf
Expand Down
3 changes: 2 additions & 1 deletion containers/cleanair/gpjax_models/cli/urbanair_jax
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#! /usr/bin/env python
"""Thenserflow 2 GPflow Models CLI"""
"""GP JAX Models CLI"""
from curses import echo
import importlib
import typer
Expand All @@ -8,5 +8,6 @@ from gpjax_models.parser import dataset, model
app = typer.Typer()
app.add_typer(dataset.app, name="dataset")
app.add_typer(model.app, name="model")
# app.add_typer(visualization.app, name="visualization")
if __name__ == "__main__":
app()
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ def get_X(df):
return np.array(
df[["epoch", "lat", "lon", "value_200_total_a_road_primary_length"]]
)
# Alternative feature set without the road-length covariate: np.array(df[["epoch", "lat", "lon"]])


def get_X_trf(df):
Expand Down Expand Up @@ -275,7 +276,6 @@ def generate_data_laqn(train_data, test_data):
train_laqn_df = train_data["laqn"]
test_laqn_df = test_data["laqn"]
test_hexgrid_df = test_data["hexgrid"]

# Extract X and Y for training data
train_laqn_X, train_laqn_Y = process_data(train_laqn_df)

Expand Down Expand Up @@ -313,7 +313,7 @@ def generate_data_laqn(train_data, test_data):
"test": {"laqn": {"df": test_laqn_df}, "hexgrid": {"df": test_hexgrid_df}},
}

with open("raw_data_svgp.pkl", "wb") as file:
with open("raw_data_svgp_only_laqn_best.pkl", "wb") as file:
pickle.dump(meta_dict, file)

return train_dict, test_dict
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,6 @@
from scipy.cluster.vq import kmeans2
import objax
import jax
from jax.config import config as jax_config

jax_config.update("jax_enable_x64", True)
from tqdm import trange
from abc import abstractmethod
from .predicting.utils import batch_predict
Expand Down Expand Up @@ -201,14 +198,12 @@ def _reshape_pred(X):

# Save predictions
with open(
os.path.join(self.results_path, "predictions_mrdgp_7.pkl"), "wb"
os.path.join(self.results_path, "predictions_mrdgp.pkl"), "wb"
) as file:
pickle.dump(results, file)

# Save inducing points
with open(
os.path.join(self.results_path, "inducing_points_7.pkl"), "wb"
) as file:
with open(os.path.join(self.results_path, "inducing_points.pkl"), "wb") as file:
pickle.dump(inducing_points, file)

# Print model and inducing points
Expand Down
30 changes: 20 additions & 10 deletions containers/cleanair/gpjax_models/gpjax_models/models/stgp_svgp.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,7 @@
from jax.example_libraries import stax
from jax import random
from scipy.cluster.vq import kmeans2
from jax.config import config as jax_config

jax_config.update("jax_enable_x64", True)
import stgp
from stgp.models import GP
from stgp.kernels import ScaleKernel, RBF
Expand All @@ -26,6 +24,7 @@
class STGP_SVGP:
def __init__(
self,
results_path, # Required: directory where predictions and inducing points are saved
M: int = 100,
batch_size: int = 100,
num_epochs: int = 10,
Expand All @@ -34,14 +33,19 @@ def __init__(
Initialize the JAX-based Air Quality Gaussian Process Model.
Args:
results_path (str): Path to the directory for saving results.
M (int): Number of inducing variables.
batch_size (int): Batch size for training.
num_epochs (int): Number of training epochs.
"""
self.results_path = results_path
self.M = M
self.batch_size = batch_size
self.num_epochs = num_epochs

# Ensure the results directory exists
os.makedirs(self.results_path, exist_ok=True)

def fit(self, x_train: np.ndarray, y_train: np.ndarray, pred_data) -> list[float]:
"""
Fit the model to training data.
Expand All @@ -58,7 +62,6 @@ def get_laqn_svgp(X_laqn, Y_laqn):
N, D = X_laqn.shape

data = Data(X_laqn, Y_laqn)
# data = TransformedData(data, [Log()])

Z = stgp.sparsity.FullSparsity(Z=kmeans2(X_laqn, 200, minit="points")[0])

Expand All @@ -67,7 +70,7 @@ def get_laqn_svgp(X_laqn, Y_laqn):
kernel=ScaleKernel(
RBF(
input_dim=D,
lengthscales=[0.1, 0.1, 0.1, 0.1],
lengthscales=[0.1] * D,
),
variance=np.nanstd(Y_laqn),
),
Expand Down Expand Up @@ -97,10 +100,9 @@ def train_laqn(num_epoch, m_laqn):
joint_grad = GradDescentTrainer(m_laqn, objax.optimizer.Adam)

lc_arr = []
num_epochs = num_epoch
laqn_natgrad.train(1.0, 1)

for i in trange(num_epochs):
for i in trange(num_epoch):
lc_i, _ = joint_grad.train(0.01, 1)
lc_arr.append(lc_i)

Expand Down Expand Up @@ -140,12 +142,20 @@ def pred_wrapper(XS):
results = predict_laqn_svgp(pred_data, m)

print(results["metrics"])
# Save the loss values to a pickle file
with open("loss_values_svgp_laqn.pkl", "wb") as file:
pickle.dump(loss_values, file)
with open("predictions_svgp_laqn.pkl", "wb") as file:
# Save predictions
with open(
os.path.join(self.results_path, "predictions_svgp_laqn__.pkl"), "wb"
) as file:
pickle.dump(results, file)

# Save inducing points
inducing_points = m.prior[0].sparsity.inducing_locations
with open(
os.path.join(self.results_path, "inducing_points_svgp_laqn__.pkl"), "wb"
) as file:
pickle.dump(inducing_points, file)
return loss_values


class STGP_SVGP_SAT:
def __init__(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def train_svgp_sat(
batch_size (int): Batch size for training.
num_epochs (int): Number of training epochs.
"""
model = STGP_SVGP_SAT(M, batch_size, num_epochs, random_seed=42)
model = STGP_SVGP_SAT(M, batch_size, num_epochs)

typer.echo("Loading training data!")
# Iterate over the directories and subdirectories
Expand Down Expand Up @@ -138,7 +138,7 @@ def train_svgp_sat(

@app.command()
def train_svgp_laqn(
root_dir: str,
root_dir: Path,
M: int = 500,
batch_size: int = 200,
num_epochs: int = 1000,
Expand Down Expand Up @@ -209,8 +209,8 @@ def train_mrdgp(
root_dir: Path,
M: Optional[int] = 500,
batch_size: Optional[int] = 200,
num_epochs: Optional[int] = 100,
pretrain_epochs: Optional[int] = 100,
num_epochs: Optional[int] = 1500,
pretrain_epochs: Optional[int] = 1500,
):
"""
Train the MRDGP model on the given training data.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""Parser for GPflow 2 model fitting."""
"""Parser for JAX model fitting."""

import typer
from . import fit_cli
Expand Down

0 comments on commit 1478123

Please sign in to comment.