apply ruff
ma595 committed Jul 23, 2024
1 parent ba2aec6 commit 337871d
Showing 4 changed files with 63 additions and 40 deletions.
74 changes: 45 additions & 29 deletions Tools/ML.py
@@ -28,7 +28,7 @@ def collect_data(
 
     # extract data
     extr_var = extract_X.var(packdata, ipft)
 
     # extract PFT map
     pft_ny = extract_X.pft(packdata, PFT_mask_lai, ipft)
     pft_ny = np.resize(pft_ny, (*extr_var.shape[:-1], 1))
@@ -56,12 +56,14 @@ def combine_data(frames, keys):
         raise ValueError("DataFrames have different columns")
     check_same = {}
     for col in columns:
-        check_same[col] = all((frame[col] == frames[0][col]).dropna().all() for frame in frames)
-    same_cols = [col for col, same in check_same.items() if same or col == 'pft']
+        check_same[col] = all(
+            (frame[col] == frames[0][col]).dropna().all() for frame in frames
+        )
+    same_cols = [col for col, same in check_same.items() if same or col == "pft"]
     df = pd.concat([df.drop(columns=same_cols) for df in frames], keys=keys, axis=1)
     df.columns = [f"{c}_{k}" for k, c in df.columns]
     df = pd.concat([df, frames[0][same_cols]], axis=1)
-    df = df.drop(columns=['pft']).dropna()
+    df = df.drop(columns=["pft"]).dropna()
     return df
 
 
@@ -86,8 +88,8 @@ def MLmap_multidim(
     col_type = "None"
     type_val = "None"
     combineXY = combine_XY
-    Y = combineXY.filter(regex='^Y_')
 
+    Y = combineXY.filter(regex="^Y_")
     X = combineXY.drop(columns=Y.columns)
 
     # combine_XY=pd.get_dummies(combine_XY) # one-hot encoded
@@ -142,8 +144,20 @@ def MLmap_multidim(
     return MLeval.evaluation_map(Global_Predicted_Y_map, Y, PFT_mask)
 
 
-def plot_eval_results(Global_Predicted_Y_map, ipool, pool_map, combineXY, predY_train, varname, ind, ii, ipft, PFT_mask, resultpath, logfile):
-
+def plot_eval_results(
+    Global_Predicted_Y_map,
+    ipool,
+    pool_map,
+    combineXY,
+    predY_train,
+    varname,
+    ind,
+    ii,
+    ipft,
+    PFT_mask,
+    resultpath,
+    logfile,
+):
     # evaluation
     R2, RMSE, slope, reMSE, dNRMSE, sNRMSE, iNRMSE, f_SB, f_SDSD, f_LSC = (
         MLeval.evaluation_map(Global_Predicted_Y_map, pool_map, ipft, PFT_mask)
@@ -257,29 +271,31 @@ def MLloop(
             if ipft in ii["skip_loop"]["pft"]:
                 continue
 
-            dim_ind, = zip(ii["dim_loop"], ind)
-
-            comb_ds[ipool].append((
-                collect_data(
-                    packdata,
-                    ivar,
-                    ipool,
-                    PFT_mask_lai,
-                    ipft,
-                    varname,
-                    ind,
-                    ii,
-                    labx,
-                    varlist,
-                    logfile,
-                ),
-                f"{varname}_{dim_ind[0]}_{dim_ind[1]}"
-            ))
+            (dim_ind,) = zip(ii["dim_loop"], ind)
+
+            comb_ds[ipool].append(
+                (
+                    collect_data(
+                        packdata,
+                        ivar,
+                        ipool,
+                        PFT_mask_lai,
+                        ipft,
+                        varname,
+                        ind,
+                        ii,
+                        labx,
+                        varlist,
+                        logfile,
+                    ),
+                    f"{varname}_{dim_ind[0]}_{dim_ind[1]}",
+                )
+            )
             break
 
         # close&save netCDF file
         restnc.close()
 
         if len(comb_ds[ipool]) > 3:
             break
 
@@ -288,7 +304,7 @@ def MLloop(
     for ipool, vals in comb_ds.items():
         df = combine_data(*zip(*vals))
         df.to_csv(f"{resultpath}/{ipool}.csv")
 
         res = MLmap_multidim(
             packdata,
             df,
@@ -302,5 +318,5 @@ def MLloop(
             missVal,
         )
         results.append(res)
 
     return pd.concat(results, keys=comb_ds.keys(), names=["component"])
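
The reformatting in combine_data above is easier to follow with a toy run of the same concat-and-flatten pattern; the frames, keys, and column names below are made up for illustration and do not come from the repository.

import pandas as pd

frames = [
    pd.DataFrame({"a": [1, 2], "pft": [3, 3], "y": [0.1, 0.2]}),
    pd.DataFrame({"a": [1, 2], "pft": [3, 3], "y": [0.3, 0.4]}),
]
keys = ["run1", "run2"]

# Columns identical across frames (plus "pft") are kept only once.
same_cols = ["a", "pft"]
df = pd.concat([f.drop(columns=same_cols) for f in frames], keys=keys, axis=1)
df.columns = [f"{c}_{k}" for k, c in df.columns]  # flatten the MultiIndex to y_run1, y_run2
df = pd.concat([df, frames[0][same_cols]], axis=1)
df = df.drop(columns=["pft"]).dropna()
print(df.columns.tolist())  # ['y_run1', 'y_run2', 'a']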
2 changes: 1 addition & 1 deletion Tools/extract_X.py
@@ -55,4 +55,4 @@ def var(packdata, ipft):
         extr_var.append(extracted_var.reshape(-1, len(packdata.Nlat), 1))
     com_shape = max(map(np.shape, extr_var))
     extr_var = [np.resize(a, com_shape) for a in extr_var]
-    return np.concatenate(extr_var, axis=-1)
+    return np.concatenate(extr_var, axis=-1)
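
For context on the surrounding (unchanged) lines: var pads its arrays to a common shape with np.resize before stacking them along the last axis. A minimal sketch of that pattern, with made-up array sizes:

import numpy as np

a = np.arange(6).reshape(2, 3, 1)   # two arrays whose leading dimensions differ
b = np.arange(3).reshape(1, 3, 1)

com_shape = max(map(np.shape, [a, b]))              # tuple comparison picks (2, 3, 1)
padded = [np.resize(x, com_shape) for x in [a, b]]  # np.resize repeats data cyclically to fill the larger shape
out = np.concatenate(padded, axis=-1)
print(out.shape)  # (2, 3, 2)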
8 changes: 5 additions & 3 deletions Tools/readvar.py
@@ -88,7 +88,9 @@ def readvar(varlist, config, logfile):
     # packdata.Tamp = packdata.Tmax - packdata.Tmin
 
     # 0.1.2 Other climatic variables (Rainf,Snowf,Qair,Psurf,SWdown,LWdown)
-    packdata.update((k, (["year", "month", "lat", "lon"], adict[f"MY{k}"])) for k in varname_clim)
+    packdata.update(
+        (k, (["year", "month", "lat", "lon"], adict[f"MY{k}"])) for k in varname_clim
+    )
     # for index in range(len(varname_clim)):
     #     if varname_clim[index] == "Tair":
     #         continue
@@ -167,12 +169,12 @@ def readvar(varlist, config, logfile):
         da[da == predvar[ipred]["missing_value"]] = np.nan
         if isinstance(da, np.ma.masked_array):
             da = da.filled(np.nan)
-        packdata[rename[ivar]] = (["veget", "lat", "lon"][-da.ndim:], da)
+        packdata[rename[ivar]] = (["veget", "lat", "lon"][-da.ndim :], da)
 
     ds = xarray.Dataset(packdata)
 
     # 0.3 Interactions between variables
-    ds['interx'] = ds.Tair * ds.Rainf
+    ds["interx"] = ds.Tair * ds.Rainf
     # packdata.interx2 = packdata.Temp_GS * packdata.Pre_GS
 
     ds.attrs.update(
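
The ["veget", "lat", "lon"][-da.ndim :] expression reformatted above keeps only as many dimension names as the array has axes before everything is wrapped in an xarray.Dataset. A rough sketch of the idiom on dummy arrays (the variable names field2d and field3d are hypothetical, not names read by readvar):

import numpy as np
import xarray

packdata = {}
da2 = np.zeros((4, 5))      # a 2-D (lat, lon) field
da3 = np.zeros((3, 4, 5))   # a 3-D (veget, lat, lon) field

packdata["field2d"] = (["veget", "lat", "lon"][-da2.ndim :], da2)  # -> ("lat", "lon")
packdata["field3d"] = (["veget", "lat", "lon"][-da3.ndim :], da3)  # -> ("veget", "lat", "lon")

ds = xarray.Dataset(packdata)
print(ds["field2d"].dims)  # ('lat', 'lon')
print(ds["field3d"].dims)  # ('veget', 'lat', 'lon')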
19 changes: 12 additions & 7 deletions Tools/train.py
@@ -39,17 +39,17 @@
 ##@retval predY predicted Y
 def training_BAT(X, Y, logfile, loocv):
     print("Data shapes: ", X.shape, Y.shape)
 
     # run the KMeans algorithm to find the cluster centers, and resample the data
     mod = KMeans(n_clusters=3)
     lab = mod.fit_predict(Y)
     count = Counter(lab)
     check.display("Counter(lab):" + str(count), logfile)
     over_samples = SMOTE()
-    over_samples_X, over_samples_y = over_samples.fit_resample(pd.concat([X, Y], axis=1), lab)
-    check.display(
-        "Counter(over_samples_y):" + str(Counter(over_samples_y)), logfile
+    over_samples_X, over_samples_y = over_samples.fit_resample(
+        pd.concat([X, Y], axis=1), lab
     )
+    check.display("Counter(over_samples_y):" + str(Counter(over_samples_y)), logfile)
     X = over_samples_X[X.columns]
     Y = over_samples_X[Y.columns]
     print("Data shapes after resampling: ", X.shape, Y.shape)
@@ -62,10 +62,15 @@ def training_BAT(X, Y, logfile, loocv):
     #     optimizer=optim.Adam,
     #     # device="cuda",
     # )
-    model = MLPRegressor(hidden_layer_sizes=(64, 64), max_iter=100,
-        learning_rate='invscaling', learning_rate_init=0.1, verbose=True)
+    model = MLPRegressor(
+        hidden_layer_sizes=(64, 64),
+        max_iter=100,
+        learning_rate="invscaling",
+        learning_rate_init=0.1,
+        verbose=True,
+    )
 
     model.fit(X, Y)
     predY = model.predict(X)
 
     return model, predY
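
For readers unfamiliar with the pattern that training_BAT wraps: it clusters the targets with KMeans, balances those pseudo-classes with SMOTE, and then fits the MLP on the resampled table. A standalone sketch of that flow is below; the random data, seed, and hyperparameters are placeholders, and the check.display logging is replaced by plain prints.

from collections import Counter

import numpy as np
import pandas as pd
from imblearn.over_sampling import SMOTE
from sklearn.cluster import KMeans
from sklearn.neural_network import MLPRegressor

rng = np.random.default_rng(0)
X = pd.DataFrame(rng.normal(size=(300, 5)), columns=[f"x{i}" for i in range(5)])
Y = pd.DataFrame(rng.normal(size=(300, 2)), columns=["Y_a", "Y_b"])

lab = KMeans(n_clusters=3, n_init=10).fit_predict(Y)  # pseudo-classes derived from the targets
print(Counter(lab))

over_X, over_lab = SMOTE().fit_resample(pd.concat([X, Y], axis=1), lab)
print(Counter(over_lab))  # the pseudo-classes are now balanced

X_bal, Y_bal = over_X[X.columns], over_X[Y.columns]
model = MLPRegressor(hidden_layer_sizes=(64, 64), max_iter=100)
model.fit(X_bal, Y_bal)
print(model.predict(X_bal).shape)  # (n_resampled_samples, 2)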
