Skip to content

Commit

Permalink
handle var extraction for data with multiple timesteps
Browse files Browse the repository at this point in the history
  • Loading branch information
tztsai authored and ma595 committed Jul 23, 2024
1 parent e75c5d7 commit ba2aec6
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 55 deletions.
15 changes: 7 additions & 8 deletions Tools/ML.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,23 +30,22 @@ def collect_data(
extr_var = extract_X.var(packdata, ipft)

# extract PFT map
pft_ny = extract_X.pft(packdata, PFT_mask_lai, ipft).reshape(len(packdata.Nlat), 1)
pft_ny = extract_X.pft(packdata, PFT_mask_lai, ipft)
pft_ny = np.resize(pft_ny, (*extr_var.shape[:-1], 1))

# extract Y
pool_arr = np.full(len(packdata.Nlat), np.nan)
pool_map = np.squeeze(ivar)[
tuple(i - 1 for i in ind)
] # all indices start from 1, but python loop starts from 0
pool_map[pool_map >= 1e18] = np.nan
if "format" in varlist["resp"] and varlist["resp"]["format"] == "compressed":
for cc in range(len(packdata.Nlat)):
pool_arr[cc] = pool_map.flatten()[cc]
pool_arr = pool_map.flatten()
else:
for cc in range(len(packdata.Nlat)):
pool_arr[cc] = pool_map[packdata.Nlat[cc], packdata.Nlon[cc]]
pool_arr = pool_map[packdata.Nlat, packdata.Nlon]
extracted_Y = np.resize(pool_arr, (*extr_var.shape[:-1], 1))

extracted_Y = np.reshape(pool_arr, (len(packdata.Nlat), 1))
extr_all = np.concatenate((extracted_Y, extr_var, pft_ny), axis=1)
extr_all = np.concatenate((extracted_Y, extr_var, pft_ny), axis=-1)
extr_all = extr_all.reshape(-1, extr_all.shape[-1])
return DataFrame(extr_all, columns=labx) # convert the array into dataframe


Expand Down
1 change: 1 addition & 0 deletions Tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import matplotlib
import numpy as np
import pandas as pd
import xarray as xr
from netCDF4 import Dataset
import calendar

Expand Down
57 changes: 17 additions & 40 deletions Tools/extract_X.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,64 +18,41 @@


##@param[in] packdata packaged data
##@param[in] var_ind index of variable
##@param[in] VarName list of variables' names
##@param[in] var_name name of variable
##@retval VarN variable values of selected pixels
def extract_X(packdata, var_ind, VarName):
Nlat = packdata.Nlat
Nlon = packdata.Nlon
varN = np.full(len(Nlat), np.nan)
var_data = packdata[VarName[var_ind]]
for cc in range(0, len(Nlat)):
varN[cc] = var_data[Nlat[cc], Nlon[cc]]
return varN
def extract_X(packdata, var_name):
var = packdata[var_name].values
return var[..., packdata.Nlat, packdata.Nlon]


##@param[in] packdata packaged data
##@param[in] var_ind index of variable
##@param[in] VarName list of variables' names
##@param[in] var_name name of variable
##@param[in] px index of PFT
##@retval VarN variable values of selected pixels
def extract_XN(packdata, var_ind, VarName, px):
Nlat = packdata.Nlat
Nlon = packdata.Nlon
varN = np.full(len(Nlat), np.nan)
var_data = packdata[VarName[var_ind]]
var_pft_map = var_data[px - 1]
for cc in range(0, len(Nlat)):
varN[cc] = var_pft_map[Nlat[cc], Nlon[cc]]
return varN
def extract_XN(packdata, var_name, px):
var = packdata[var_name].values
return var[px - 1, packdata.Nlat, packdata.Nlon]


##@param[in] packdata packaged data
##@param[in] PFT_mask PFT mask
##@param[in] px index of PFT
##@retval VarN variable values of selected pixels
def pft(packdata, PFT_mask, px):
Nlat = packdata.Nlat
Nlon = packdata.Nlon
varN = np.full(len(Nlat), np.nan)
for cc in range(0, len(Nlat)):
varN[cc] = PFT_mask[px - 1, Nlat[cc], Nlon[cc]]
return varN
return PFT_mask[px - 1, packdata.Nlat, packdata.Nlon]


##@param[in] packdata packaged data
##@param[in] ipft ith pft
##@retval extr_var extracked data
def var(packdata, ipft):
extr_var = []
for indx in range(packdata.Nv_total):
if indx < packdata.Nv_nopft:
extracted_var = np.reshape(
extract_X(packdata, indx, packdata.var_pred_name),
(len(packdata.Nlat), 1),
)
extr_var.append(extracted_var)
for var_name in packdata.data_vars:
if "veget" not in packdata[var_name].dims:
extracted_var = extract_X(packdata, var_name)
else:
extracted_var = np.reshape(
extract_XN(packdata, indx, packdata.var_pred_name, ipft),
(len(packdata.Nlat), 1),
)
extr_var.append(extracted_var)
return np.hstack(extr_var)
extracted_var = extract_XN(packdata, var_name, ipft)
extr_var.append(extracted_var.reshape(-1, len(packdata.Nlat), 1))
com_shape = max(map(np.shape, extr_var))
extr_var = [np.resize(a, com_shape) for a in extr_var]
return np.concatenate(extr_var, axis=-1)
14 changes: 7 additions & 7 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,9 +166,9 @@
if "4" in itask:
# ML extrapolation

var_pred_name1 = varlist["pred"]["allname"]
var_pred_name2 = varlist["pred"]["allname_pft"]
var_pred_name = var_pred_name1 + var_pred_name2
# var_pred_name1 = varlist["pred"]["allname"]
# var_pred_name2 = varlist["pred"]["allname_pft"]
# var_pred_name = var_pred_name1 + var_pred_name2
# packdata.Nv_nopft = len(var_pred_name1)
# packdata.Nv_total = len(var_pred_name)
# packdata.var_pred_name = var_pred_name
Expand All @@ -187,13 +187,13 @@
# packdata.attrs['Nlat'] = np.trunc((90 - IDx[:, 0]) / packdata.lat_reso).astype(int)
# packdata.attrs['Nlon'] = np.trunc((180 + IDx[:, 1]) / packdata.lon_reso).astype(int)
packdata.attrs.update(
Nv_nopft=len(var_pred_name1),
Nv_total=len(var_pred_name),
var_pred_name=var_pred_name,
# Nv_nopft=len(var_pred_name1),
# Nv_total=len(var_pred_name),
# var_pred_name=var_pred_name,
Nlat=np.trunc((90 - IDx[:, 0]) / packdata.lat_reso).astype(int),
Nlon=np.trunc((180 + IDx[:, 1]) / packdata.lon_reso).astype(int),
)
labx = ["Y"] + var_pred_name + ["pft"]
labx = ["Y"] + list(packdata.data_vars) + ["pft"]

# copy the restart file to be modified
targetfile = (
Expand Down

0 comments on commit ba2aec6

Please sign in to comment.