From 1e4465b3e8c211acaa0bd214eca2f9d11738317a Mon Sep 17 00:00:00 2001
From: Oliver Hamelijnck
Date: Tue, 18 Jun 2024 09:49:04 +0100
Subject: [PATCH] Initial attempt at satellite visualisation

---
 .../gpjax_models/gpjax_models/models/vis.py | 54 +++++++++++--------
 1 file changed, 32 insertions(+), 22 deletions(-)

diff --git a/containers/cleanair/gpjax_models/gpjax_models/models/vis.py b/containers/cleanair/gpjax_models/gpjax_models/models/vis.py
index 21dce2225..dd6688a6a 100644
--- a/containers/cleanair/gpjax_models/gpjax_models/models/vis.py
+++ b/containers/cleanair/gpjax_models/gpjax_models/models/vis.py
@@ -8,7 +8,7 @@
 from stdata.vis.spacetime import SpaceTimeVisualise
 
 
-directory_path = "/clean-air/clean-air-infrastructure/containers/cleanair/gpjax_models/data/dgp_small_inducing_and_maxiter "
+directory_path = Path("/Users/oliverhamelijnck/Downloads/dataset/")
 
 # Create the directory if it doesn't exist
 if not os.path.exists(directory_path):
@@ -18,27 +18,27 @@ def load_data(root):
 
     with open(
         str(
-            "/clean-air/clean-air-infrastructure/containers/cleanair/gpjax_models/data/dgp_small_inducing_and_maxiter/3/dataset/training_dataset.pkl"
+            directory_path / "training_dataset.pkl"
         ),
         "rb",
     ) as file:
         training_data = pd.read_pickle(file)
 
     with open(
         str(
-            "/clean-air/clean-air-infrastructure/containers/cleanair/gpjax_models/data/dgp_small_inducing_and_maxiter/3/dataset/test_dataset.pkl"
+            directory_path / "test_dataset.pkl"
         ),
         "rb",
     ) as file:
         testing_data = pd.read_pickle(file)
     # Load raw data using pickle
     with open(
-        "/clean-air/clean-air-infrastructure/raw_data_3.pkl",
+        directory_path / "raw_data_3.pkl",
         "rb",
     ) as file:
         raw_data = pd.read_pickle(file)
     with open(
-        "/clean-air/clean-air-infrastructure/test_true_dataset.pkl",
+        directory_path / "test_true_dataset.pkl",
         "rb",
     ) as file:
         true_y = pd.read_pickle(file)
@@ -74,7 +74,7 @@ def load_data(root):
 
 def load_results(root):
     with open(
-        str(root / "mrdgp_production_results" / "predictions_mrdgp.pkl"), "rb"
+        str(root / "predictions_mrdgp_3.pkl"), "rb"
     ) as file:
         results = pd.read_pickle(file)
     return results
@@ -90,11 +90,11 @@ def fix_df_columns(df):
 
 
 if __name__ == "__main__":
-    data_root = Path("containers/cleanair/gpjax_models/data")
+    data_root = directory_path
     training_data, testing_data, raw_data, day_3_gt = load_data(data_root)
     day_3_gt = day_3_gt.reset_index(drop=True)
 
-    train_laqn_df = fix_df_columns_dropna(raw_data["train"]["laqn"]["df"])
+    train_laqn_df = fix_df_columns(raw_data["train"]["laqn"]["df"])
     test_laqn_df = fix_df_columns(raw_data["test"]["laqn"]["df"])
     true_val = fix_df_columns(day_3_gt)
     # test_laqn_true_values = true_y
@@ -120,20 +120,30 @@ def fix_df_columns(df):
     # )
     train_end = train_laqn_df["epoch"].max()
     laqn_df = pd.concat([train_laqn_df, test_laqn_df])
-    #'geom' is the column containing Shapely Point geometries
-    hexgrid_df["geom"] = gpd.points_from_xy(hexgrid_df["lon"], hexgrid_df["lat"])
-
-    # Buffer each Point geometry by 0.002
-    hexgrid_df["geom"] = hexgrid_df["geom"].apply(lambda point: point.buffer(0.002))
-
-    # Create a GeoDataFrame using the 'geom' column
-    hexgrid_gdf = gpd.GeoDataFrame(hexgrid_df, geometry="geom")
-    hexgrid_df["pred"] = results["predictions"]["hexgrid"]["mu"][0].T
-    # hexgrid_df["pred"] = hexgrid_df["traffic"]
-    hexgrid_df["var"] = np.squeeze(results["predictions"]["hexgrid"]["var"][0])
-    vis_obj = SpaceTimeVisualise(
-        laqn_df, hexgrid_df, geopandas_flag=True, test_start=train_end
-    )
+    if False:
+        # 'geom' is the column containing Shapely Point geometries
+        hexgrid_df["geom"] = gpd.points_from_xy(hexgrid_df["lon"], hexgrid_df["lat"])
+
+        # Buffer each Point geometry by 0.002
+        hexgrid_df["geom"] = hexgrid_df["geom"].apply(lambda point: point.buffer(0.002))
+
+        # Create a GeoDataFrame using the 'geom' column
+        hexgrid_gdf = gpd.GeoDataFrame(hexgrid_df, geometry="geom")
+        hexgrid_df["pred"] = results["predictions"]["hexgrid"]["mu"][0].T
+        # hexgrid_df["pred"] = hexgrid_df["traffic"]
+        hexgrid_df["var"] = np.squeeze(results["predictions"]["hexgrid"]["var"][0])
+    else:
+        hexgrid_df = None
+
+    sat_df = fix_df_columns(raw_data["train"]["sat"]["df"])
+    # TODO: check that this aggregation matches how the satellite data is handled upstream
+    sat_df = sat_df[["lon", "lat", "NO2", "epoch", "box_id"]].groupby(["epoch", "box_id"]).mean().reset_index()
+    # attach the satellite predictions and variances
+    sat_df["pred"] = results["predictions"]["sat"]["mu"][0]
+    sat_df["var"] = results["predictions"]["sat"]["var"][0]
+    sat_df["observed"] = sat_df["NO2"]
+
+    vis_obj = SpaceTimeVisualise(laqn_df, hexgrid_df, sat_df=sat_df, geopandas_flag=True, test_start=train_end)
 
     # Show the visualization
     vis_obj.show()
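Note on the new satellite branch: it collapses the raw per-pixel satellite rows to one row per (epoch, box_id) pair and then attaches the model outputs column-wise. Below is a minimal, self-contained sketch of that aggregation step; the toy DataFrame, the stand-in prediction values, and the fixed variance are illustrative assumptions rather than values from the cleanair pipeline, and SpaceTimeVisualise itself is not exercised here. Presumably results["predictions"]["sat"]["mu"][0] is ordered to match the aggregated rows, which is what the TODO in the patch still needs to confirm.

# Sketch of the satellite aggregation performed in the patch.
# Toy data and stand-in predictions only; not the cleanair pipeline values.
import numpy as np
import pandas as pd

# Hypothetical satellite frame: several grid points per (epoch, box_id).
sat_df = pd.DataFrame(
    {
        "lon": [-0.12, -0.11, -0.12, -0.11],
        "lat": [51.50, 51.51, 51.50, 51.51],
        "NO2": [40.0, 42.0, 38.0, 39.0],
        "epoch": [0, 0, 1, 1],
        "box_id": [7, 7, 7, 7],
    }
)

# Collapse to one row per (epoch, box_id), mirroring the groupby in the patch.
sat_df = (
    sat_df[["lon", "lat", "NO2", "epoch", "box_id"]]
    .groupby(["epoch", "box_id"])
    .mean()
    .reset_index()
)

# Attach predictions; random stand-ins for results["predictions"]["sat"]["mu"/"var"].
rng = np.random.default_rng(0)
sat_df["pred"] = rng.normal(loc=40.0, scale=1.0, size=len(sat_df))
sat_df["var"] = 1.0
sat_df["observed"] = sat_df["NO2"]
print(sat_df)

Running the sketch prints one aggregated row per (epoch, box_id) with observed, pred, and var columns, which is the shape the visualiser is being handed via the sat_df argument in the patch.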