Commit: update train mrdgp with file manager

Sueda Ciftci committed Jul 20, 2024
1 parent 602f282 commit 1565aa9

Showing 2 changed files with 76 additions and 58 deletions.
@@ -4,16 +4,18 @@
import pickle
import pandas as pd
from typing import Optional
+from pathlib import Path
import os
import jax
import optax
import numpy as np
import jax.numpy as jnp

-from jax.config import config as jax_config
-
-jax_config.update("jax_enable_x64", True)
+from jax import config
+
+# Enable 64-bit (double) precision in JAX
+config.update("jax_enable_x64", True)
+from ...utils.file_manager import FileManager
from ...models.svgp import SVGP
from ...models.stgp_svgp import STGP_SVGP_SAT, STGP_SVGP
from ...models.stgp_mrdgp import STGP_MRDGP
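
For context, a minimal sketch (not part of the commit) of what the jax_enable_x64 flag changes: without it, JAX silently downcasts 64-bit values to float32.

from jax import config
config.update("jax_enable_x64", True)

import jax.numpy as jnp
print(jnp.array([1.0]).dtype)  # float64 with the flag set; float32 by default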
@@ -25,12 +27,6 @@

app = typer.Typer()

-# Defining blob storage and other constants here
-RESOURCE_GROUP = "Datasets"
-STORAGE_CONTAINER_NAME = "aqdata"
-STORAGE_ACCOUNT_NAME = "londonaqdatasets"
-ACCOUNT_URL = "https://londonaqdatasets.blob.core.windows.net/"
-

@app.command()
def svgp(
@@ -77,7 +73,7 @@ def train_svgp_sat(
    root_dir: str,
    M: int = 500,
    batch_size: int = 200,
-    num_epochs: int = 2500,
+    num_epochs: int = 50,
):
    """
    Train the SVGP_GPF2 model on the given training data.
@@ -135,6 +131,7 @@ def train_svgp_sat(
        },
    }
    # Train the model
    model.fit(x_sat, y_sat, pred_laqn_data, pred_sat_data)
    typer.echo("Training complete!")

@@ -209,7 +206,7 @@ def train_svgp_laqn(
# TODO: make one train command that reads the config to get the model name
@app.command()
def train_mrdgp(
-    root_dir: str,
+    root_dir: Path,
    M: Optional[int] = 500,
    batch_size: Optional[int] = 200,
    num_epochs: Optional[int] = 100,
@@ -226,36 +223,25 @@ def train_mrdgp(
"""

model = STGP_MRDGP(M, batch_size, num_epochs, pretrain_epochs, root_dir)
file_manager = FileManager(root_dir)
# Load training data
typer.echo("Loading training data!")
# Iterate over the directories and subdirectories
for dirpath, _, filenames in os.walk(root_dir):
# Check if 'training_dataset.pkl' exists in the current directory
if "training_dataset.pkl" in filenames:
# If found, load the data
file_path = os.path.join(dirpath, "training_dataset.pkl")
with open(file_path, "rb") as file:
train_data = pickle.load(file)
training_data = file_manager.load_training_data()

typer.echo("Loading testing data!")
for dirpath, _, filenames in os.walk(root_dir):
# Check if 'training_dataset.pkl' exists in the current directory
if "test_dataset.pkl" in filenames:
# If found, load the data
file_path = os.path.join(dirpath, "test_dataset.pkl")
with open(file_path, "rb") as file:
test_data_dict = pickle.load(file)
test_data = file_manager.load_training_data()

train_dict, test_dict = generate_data(train_data, test_data_dict)
x_laqn = train_dict["laqn"]["X"]
y_laqn = train_dict["laqn"]["Y"]
x_sat = train_dict["sat"]["X"]
y_sat = train_dict["sat"]["Y"]
breakpoint()
training_dict, test_dict = generate_data(training_data, test_data)
x_laqn = training_dict["laqn"]["X"]
y_laqn = training_dict["laqn"]["Y"]
x_sat = training_dict["sat"]["X"]
y_sat = training_dict["sat"]["Y"]

    pred_sat_data = {
        "sat": {
-            "X": train_dict["sat"]["X"],
-            "Y": train_dict["sat"]["Y"],
+            "X": training_dict["sat"]["X"],
+            "Y": training_dict["sat"]["Y"],
        },
    }

@@ -268,9 +254,9 @@ def train_mrdgp(
"X": test_dict["laqn"]["X"],
"Y": None,
},
"train_laqn": {
"X": train_dict["laqn"]["X"],
"Y": train_dict["laqn"]["Y"],
"training_laqn": {
"X": training_dict["laqn"]["X"],
"Y": training_dict["laqn"]["Y"],
},
}
model.fit(x_sat, y_sat, x_laqn, y_laqn, pred_laqn_data, pred_sat_data)
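
For readers following the dictionary plumbing above, a sketch of the layout the code assumes generate_data returns (key names are from the code; the LAQN/satellite interpretation is an assumption):

# training_dict = {
#     "laqn": {"X": ..., "Y": ...},  # LAQN ground-sensor inputs and targets
#     "sat":  {"X": ..., "Y": ...},  # satellite inputs and targets
# }
# pred_laqn_data pairs the held-out test_dict["laqn"]["X"] (with Y=None)
# with a "training_laqn" entry so in-sample predictions can be scored too.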
@@ -282,21 +268,31 @@ def train_mrdgp_trf(
    root_dir: str,
    M: Optional[int] = 500,
    batch_size: Optional[int] = 200,
-    num_epochs: Optional[int] = 100,
-    pretrain_epochs: Optional[int] = 100,
+    num_epochs: Optional[int] = 50,
+    pretrain_epochs: Optional[int] = 50,
+    random_seed: Optional[int] = 0,
):
"""
Train the SVGP_GPF2 model on the given training data.
Args:
train_file_path (str): Path to the training data pickle file.
root_dir (str): Root directory containing training and testing data.
M (int): Number of inducing variables.
batch_size (int): Batch size for training.
num_epochs (int): Number of training epochs.
pretrain_epochs (int): Number of pretraining epochs.
random_seed (int): Random seed for reproducibility.
"""

+    generator = np.random.default_rng(random_seed)
    model = STGP_MRDGP(
-        M, batch_size, num_epochs, pretrain_epochs, root_dir, random_seed=42
+        M,
+        batch_size,
+        num_epochs,
+        pretrain_epochs,
+        root_dir,
+        jax_random_seed=random_seed,
+        generator=generator,
    )
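    # A plain-numpy sketch (not from this commit) of what seeding buys:
    # generators built from the same seed produce identical draws, so runs
    # with equal random_seed are repeatable.
    #     g1 = np.random.default_rng(0)
    #     g2 = np.random.default_rng(0)
    #     assert (g1.standard_normal(3) == g2.standard_normal(3)).all()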
    # Load training data
    typer.echo("Loading training data!")
@@ -311,7 +307,7 @@ def train_mrdgp_trf(

typer.echo("Loading testing data!")
for dirpath, _, filenames in os.walk(root_dir):
# Check if 'training_dataset.pkl' exists in the current directory
# Check if 'test_dataset.pkl' exists in the current directory
if "test_dataset.pkl" in filenames:
# If found, load the data
file_path = os.path.join(dirpath, "test_dataset.pkl")
56 changes: 39 additions & 17 deletions containers/cleanair/gpjax_models/gpjax_models/utils/file_manager.py
@@ -3,6 +3,7 @@
import pandas as pd
from pathlib import Path
import logging
+import os


class ExperimentInstanceNotFoundError(Exception):
@@ -21,10 +22,10 @@ class FileManager

    # Constants
    DEFAULT_TRAINING_NAME = "training"
-    DATASET = Path("datasets")
+    DATASET = Path("dataset")
    RAW_DATA_PICKLE = DATASET / "raw_data.pkl"
-    TRAINING_DATA_PICKLE = DATASET / "training_data.pkl"
-    TEST_DATA_PICKLE = DATASET / "test_dataset.pkl"
+    TRAINING_DATA_PICKLE = "training_dataset.pkl"
+    TEST_DATA_PICKLE = "test_dataset.pkl"
    RESOURCE_GROUP = "Datasets"
    STORAGE_CONTAINER_NAME = "aqdata"
    STORAGE_ACCOUNT_NAME = "londonaqdatasets"
@@ -65,9 +66,41 @@ def download_data_blob(cls, name: str = None, input_dir: Path = None) -> None:
            cls.logger.error(f"An error occurred during download: {str(e)}")

    def load_training_data(self) -> dict:
-        self.download_data_blob(name="training", input_dir=Path.cwd())
-        pickle_path = Path.cwd() / FileManager.TRAINING_DATA_PICKLE
-        return self.load_pickle(pickle_path)
+        """Load training data from the dataset directory."""
+        for dirpath, _, filenames in os.walk(self.input_dir):
+            # Check if the training pickle exists in the current directory
+            if FileManager.TRAINING_DATA_PICKLE in filenames:
+                # If found, load the data
+                file_path = os.path.join(dirpath, FileManager.TRAINING_DATA_PICKLE)
+                with open(file_path, "rb") as file:
+                    return pickle.load(file)
+        raise FileNotFoundError(
+            f"{FileManager.TRAINING_DATA_PICKLE} not found in {self.input_dir}"
+        )
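    # A sketch of an equivalent pathlib lookup (assuming input_dir is a
    # Path), which could replace the os.walk scan above:
    #     match = next(self.input_dir.rglob(FileManager.TRAINING_DATA_PICKLE), None)
    #     if match is None:
    #         raise FileNotFoundError(FileManager.TRAINING_DATA_PICKLE)
    #     with open(match, "rb") as f:
    #         return pickle.load(f)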

+    def load_testing_data(self) -> dict:
+        """Load testing data from the dataset directory."""
+        for dirpath, _, filenames in os.walk(self.input_dir):
+            # Check if the test pickle exists in the current directory
+            if FileManager.TEST_DATA_PICKLE in filenames:
+                # If found, load the data
+                file_path = os.path.join(dirpath, FileManager.TEST_DATA_PICKLE)
+                with open(file_path, "rb") as file:
+                    return pickle.load(file)
+        raise FileNotFoundError(
+            f"{FileManager.TEST_DATA_PICKLE} not found in {self.input_dir}"
+        )
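    # Usage sketch (hypothetical directory name; FileManager is constructed
    # with the experiment directory, as train_mrdgp does above):
    #     fm = FileManager(Path("experiment_dir"))
    #     training_data = fm.load_training_data()
    #     test_data = fm.load_testing_data()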

+    def load_pickle(self, pickle_path: Path) -> any:
+        """Load either training or test data from a pickled file."""
+        self.logger.debug("Loading object from pickle file at %s", pickle_path)
+        if not pickle_path.exists():
+            raise FileNotFoundError(f"Could not find file at path {pickle_path}")
+
+        with open(pickle_path, "rb") as pickle_f:
+            return pickle.load(
+                pickle_f, fix_imports=True, encoding="ASCII", errors="strict"
+            )

    def validate_input_directory(self, input_dir: Path) -> None:
        if not input_dir.exists():
@@ -81,17 +114,6 @@ def save_pickle(self, obj: any, input_dir: Path) -> None:
        with open(input_dir, "wb") as pickle_file:
            pickle.dump(obj, pickle_file)
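    # Round-trip sketch for the two pickle helpers (path is illustrative):
    #     fm.save_pickle({"a": 1}, Path("tmp.pkl"))
    #     assert fm.load_pickle(Path("tmp.pkl")) == {"a": 1}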

-    def load_pickle(self, pickle_path: Path) -> any:
-        """Load either training or test data from a pickled file."""
-        self.logger.debug("Loading object from pickle file from %s", pickle_path)
-        if not pickle_path.exists():
-            raise FileNotFoundError(f"Could not find file at path {pickle_path}")
-
-        with open(pickle_path, "rb") as pickle_f:
-            return pickle.load(
-                pickle_f, fix_imports=True, encoding="ASCII", errors="strict"
-            )

    def load_test_data(self) -> dict:
        """Load test data from either the CACHE or input_dir."""
        pickle_path = self.input_dir / FileManager.TEST_DATA_PICKLE
