Skip to content

Commit

Permalink
Merge branch 'redesing_system_add_sat_vis' into redesing_system
Browse files Browse the repository at this point in the history
  • Loading branch information
Sueda Ciftci committed Sep 13, 2024
2 parents aee7e2d + 1e4465b commit 50abe4b
Showing 1 changed file with 32 additions and 22 deletions.
54 changes: 32 additions & 22 deletions containers/cleanair/gpjax_models/gpjax_models/models/vis.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from stdata.vis.spacetime import SpaceTimeVisualise


directory_path = "/clean-air/clean-air-infrastructure/containers/cleanair/gpjax_models/data/dgp_small_inducing_and_maxiter "
directory_path = Path('/Users/oliverhamelijnck/Downloads/dataset/')

# Create the directory if it doesn't exist
if not os.path.exists(directory_path):
Expand All @@ -18,27 +18,27 @@
def load_data(root):
with open(
str(
"/clean-air/clean-air-infrastructure/containers/cleanair/gpjax_models/data/dgp_small_inducing_and_maxiter/3/dataset/training_dataset.pkl"
directory_path / "training_dataset.pkl"
),
"rb",
) as file:
training_data = pd.read_pickle(file)
with open(
str(
"/clean-air/clean-air-infrastructure/containers/cleanair/gpjax_models/data/dgp_small_inducing_and_maxiter/3/dataset/test_dataset.pkl"
directory_path/ "test_dataset.pkl"
),
"rb",
) as file:
testing_data = pd.read_pickle(file)
# Load raw data using pickle
with open(
"/clean-air/clean-air-infrastructure/raw_data_3.pkl",
directory_path / "raw_data_3.pkl",
"rb",
) as file:
raw_data = pd.read_pickle(file)

with open(
"/clean-air/clean-air-infrastructure/test_true_dataset.pkl",
directory_path / "test_true_dataset.pkl",
"rb",
) as file:
true_y = pd.read_pickle(file)
Expand Down Expand Up @@ -74,7 +74,7 @@ def load_data(root):

def load_results(root):
with open(
str(root / "mrdgp_production_results" / "predictions_mrdgp.pkl"), "rb"
str(root / "predictions_mrdgp_3.pkl"), "rb"
) as file:
results = pd.read_pickle(file)
return results
Expand All @@ -90,11 +90,11 @@ def fix_df_columns(df):


if __name__ == "__main__":
data_root = Path("containers/cleanair/gpjax_models/data")
data_root = directory_path

training_data, testing_data, raw_data, day_3_gt = load_data(data_root)
day_3_gt = day_3_gt.reset_index(drop=True)
train_laqn_df = fix_df_columns_dropna(raw_data["train"]["laqn"]["df"])
train_laqn_df = fix_df_columns(raw_data["train"]["laqn"]["df"])
test_laqn_df = fix_df_columns(raw_data["test"]["laqn"]["df"])
true_val = fix_df_columns(day_3_gt)
# test_laqn_true_values = true_y
Expand All @@ -120,20 +120,30 @@ def fix_df_columns(df):
# )
train_end = train_laqn_df["epoch"].max()
laqn_df = pd.concat([train_laqn_df, test_laqn_df])
#'geom' is the column containing Shapely Point geometries
hexgrid_df["geom"] = gpd.points_from_xy(hexgrid_df["lon"], hexgrid_df["lat"])

# Buffer each Point geometry by 0.002
hexgrid_df["geom"] = hexgrid_df["geom"].apply(lambda point: point.buffer(0.002))

# Create a GeoDataFrame using the 'geom' column
hexgrid_gdf = gpd.GeoDataFrame(hexgrid_df, geometry="geom")
hexgrid_df["pred"] = results["predictions"]["hexgrid"]["mu"][0].T
# hexgrid_df["pred"] = hexgrid_df["traffic"]
hexgrid_df["var"] = np.squeeze(results["predictions"]["hexgrid"]["var"][0])
vis_obj = SpaceTimeVisualise(
laqn_df, hexgrid_df, geopandas_flag=True, test_start=train_end
)
if False:
#'geom' is the column containing Shapely Point geometries
hexgrid_df["geom"] = gpd.points_from_xy(hexgrid_df["lon"], hexgrid_df["lat"])

# Buffer each Point geometry by 0.002
hexgrid_df["geom"] = hexgrid_df["geom"].apply(lambda point: point.buffer(0.002))

# Create a GeoDataFrame using the 'geom' column
hexgrid_gdf = gpd.GeoDataFrame(hexgrid_df, geometry="geom")
hexgrid_df["pred"] = results["predictions"]["hexgrid"]["mu"][0].T
# hexgrid_df["pred"] = hexgrid_df["traffic"]
hexgrid_df["var"] = np.squeeze(results["predictions"]["hexgrid"]["var"][0])
else:
hexgrid_df = None

sat_df = fix_df_columns(raw_data["train"]['sat']['df'])
# TODO: NEED TO CHECK!! this should match is handling the satllite data
sat_df = sat_df[['lon', 'lat', 'NO2', 'epoch', 'box_id']].groupby(['epoch', 'box_id']).mean().reset_index()
# copy predictions
sat_df['pred'] = results["predictions"]['sat']['mu'][0]
sat_df['var'] = results["predictions"]['sat']['var'][0]
sat_df['observed'] = sat_df['NO2']

vis_obj = SpaceTimeVisualise( laqn_df, hexgrid_df, sat_df=sat_df, geopandas_flag=True, test_start=train_end)

# Show the visualization
vis_obj.show()

0 comments on commit 50abe4b

Please sign in to comment.