This repository has been archived by the owner on Jun 2, 2023. It is now read-only.

Drop exclude_file parameters and functions #208

Merged · 2 commits · Nov 10, 2022
144 changes: 0 additions & 144 deletions river_dl/preproc_utils.py
@@ -276,102 +276,6 @@ def reshape_for_training(data):
return np.reshape(data, [n_batch * n_seg, seq_len, n_feat])


def get_exclude_start_end(exclude_grp):
"""
get the start and end dates for the exclude group
:param exclude_grp: [dict] dictionary representing the exclude group from
the exclude yml file
:return: [tuple of datetime objects] start date, end date
"""
start = exclude_grp.get("start_date")
if start:
start = datetime.datetime.strptime(start, "%Y-%m-%d")

end = exclude_grp.get("end_date")
if end:
end = datetime.datetime.strptime(end, "%Y-%m-%d")
return start, end


def get_exclude_vars(exclude_grp):
"""
get the variables to exclude for the exclude group
:param exclude_grp: [dict] dictionary representing the exclude group from
the exclude yml file
:return: [list] variables to exclude
"""
variable = exclude_grp.get("variable")
if not variable or variable == "both":
return ["seg_tave_water", "seg_outflow"]
elif variable == "temp":
return ["seg_tave_water"]
elif variable == "flow":
return ["seg_outflow"]
else:
raise ValueError("exclude variable must be flow, temp, or both")


def get_exclude_seg_ids(exclude_grp, all_segs):
"""
get the segments to exclude
:param exclude_grp: [dict] dictionary representing the exclude group from
the exclude yml file
:param all_segs: [array] all of the segments. this is needed if we are doing
a reverse exclusion
:return: [list like] the segments to exclude
"""
# ex_segs are the segments to exclude
if "seg_id_nats_ex" in exclude_grp.keys():
ex_segs = exclude_grp["seg_id_nats_ex"]
# exclude all *but* the "seg_id_nats_in"
elif "seg_id_nats_in" in exclude_grp.keys():
ex_mask = ~all_segs.isin(exclude_grp["seg_id_nats_in"])
ex_segs = all_segs[ex_mask]
else:
ex_segs = all_segs
return ex_segs


def exclude_segments(y_data, exclude_segs):
"""
exclude segments from being trained on by setting their weights as zero
:param y_data: [xr dataset] y dataset. this is used to get the dimensions
:param exclude_segs: [list] list of segments to exclude in the loss
calculation
:return: [xr dataset] weights dataset with excluded entries set to zero
"""
weights = initialize_weights(y_data, 1)
for seg_grp in exclude_segs:
# get the start and end dates if present
start, end = get_exclude_start_end(seg_grp)
exclude_vars = get_exclude_vars(seg_grp)
segs_to_exclude = get_exclude_seg_ids(seg_grp, weights.seg_id_nat)

# loop through the data_vars
for v in exclude_vars:
# set those weights to zero
weights[v].load()
weights[v].loc[
dict(date=slice(start, end), seg_id_nat=segs_to_exclude)
] = 0
return weights


def initialize_weights(y_data, initial_val=1):
"""
initialize all weights with a value.
:param y_data: [xr dataset] y dataset. this is used to get the dimensions
:param initial_val: [num] a number to initialize the weights with. should
be between 0 and 1 (inclusive)
:return: [xr dataset] dataset weights initialized with a uniform value
"""
weights = y_data.copy(deep=True)
for v in y_data.data_vars:
weights[v].load()
weights[v].loc[:, :] = initial_val
return weights


def reduce_training_data_random(
data_file,
train_start_date="1980-10-01",
@@ -600,7 +504,6 @@ def prep_y_data(
time_idx_name="date",
seq_len=365,
log_vars=None,
exclude_file=None,
normalize_y=True,
y_type="obs",
y_std=None,
@@ -637,7 +540,6 @@
sites will be withheld from training and validation
:param seq_len: [int] length of sequences (e.g., 365)
:param log_vars: [list-like] which variables (if any) to take the log of
:param exclude_file: [str] path to exclude file
:param normalize_y: [bool] whether or not to normalize the y values
:param y_type: [str] "obs" if observations or "pre" if pretraining
:param y_std: [array-like] standard deviations of the y variables
@@ -683,12 +585,6 @@
if log_vars:
y_trn = log_variables(y_trn, log_vars)

# filter pretrain/finetune y data
if exclude_file:
exclude_segs = read_exclude_segs_file(exclude_file)
y_wgts = exclude_segments(y_trn, exclude_segs=exclude_segs)
else:
y_wgts = initialize_weights(y_trn)
# scale y training data and get the mean and std
# scale the validation partition to benchmark epoch performance
if normalize_y:
@@ -713,9 +609,6 @@
"y_obs_trn": convert_batch_reshape(
y_trn, spatial_idx_name, time_idx_name, offset=trn_offset, seq_len=seq_len
),
"y_obs_wgts": convert_batch_reshape(
y_wgts, spatial_idx_name, time_idx_name, offset=trn_offset, seq_len=seq_len
),
"y_obs_val": convert_batch_reshape(
y_val, spatial_idx_name, time_idx_name, offset=tst_val_offset, seq_len=seq_len
),
@@ -768,7 +661,6 @@ def prep_all_data(
dist_type="updown",
catch_prop_file=None,
catch_prop_vars=None,
exclude_file=None,
log_y_vars=False,
out_file=None,
segs=None,
@@ -823,7 +715,6 @@
left unfilled, the catchment properties will not be included as predictors
:param catch_prop_vars: [list of str] list of catchment properties to use. If
left unfilled and a catchment property file is supplied, all variables will be used.
:param exclude_file: [str] path to exclude file
:param log_y_vars: [bool] whether or not to take the log of discharge in
training
:param segs: [list-like] which segments to prepare the data for
@@ -1005,7 +896,6 @@
time_idx_name=time_idx_name,
seq_len=seq_len,
log_vars=log_y_vars,
exclude_file=exclude_file,
normalize_y=normalize_y,
y_type="obs",
trn_offset=trn_offset,
@@ -1028,7 +918,6 @@
time_idx_name=time_idx_name,
seq_len=seq_len,
log_vars=log_y_vars,
exclude_file=exclude_file,
normalize_y=normalize_y,
y_type="pre",
y_std=y_obs_data["y_std"],
@@ -1053,7 +942,6 @@
time_idx_name=time_idx_name,
seq_len=seq_len,
log_vars=log_y_vars,
exclude_file=exclude_file,
normalize_y=normalize_y,
y_type="pre",
trn_offset=trn_offset,
@@ -1118,35 +1006,3 @@ def prep_adj_matrix(infile, dist_type, dist_idx_name, segs=None, out_file=None):
np.savez_compressed(out_file, dist_matrix=A_hat)
return A_hat


def read_exclude_segs_file(exclude_file):
"""
read the exclude segs file. should be a yml file of groups, each giving the
segments to exclude and, optionally, a start_date, end_date, and variable
--
example exclude file:

group_after_2017:
  start_date: "2017-10-01"
  variable: "temp"
  seg_id_nats_ex:
    - 1556
    - 1569
group_2018_water_year:
  start_date: "2017-10-01"
  end_date: "2018-10-01"
  seg_id_nats_ex:
    - 1653
group_all_time:
  seg_id_nats_in:
    - 1806
    - 2030

--
:param exclude_file: [str] exclude segs file
:return: [list] list of dictionaries of segments to exclude. each dict should
have 'seg_id_nats_ex' or 'seg_id_nats_in' and may also have 'start_date',
'end_date', and 'variable'
"""
with open(exclude_file, "r") as s:
d = yaml.safe_load(s)
return [val for key, val in d.items()]
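Note that the reader returned only the group dicts; the group names ('group_after_2017', etc.) were dropped by the final list comprehension and served purely as documentation. A short sketch of the removed round trip, using a hypothetical file path:

from river_dl.preproc_utils import read_exclude_segs_file

# write the docstring's first example group to disk (hypothetical path)
with open("exclude.yml", "w") as f:
    f.write(
        'group_after_2017:\n'
        '  start_date: "2017-10-01"\n'
        '  variable: "temp"\n'
        '  seg_id_nats_ex:\n'
        '    - 1556\n'
        '    - 1569\n'
    )

print(read_exclude_segs_file("exclude.yml"))
# [{'start_date': '2017-10-01', 'variable': 'temp',
#   'seg_id_nats_ex': [1556, 1569]}]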
1 change: 0 additions & 1 deletion workflow_examples/Snakefile_basic.smk
@@ -49,7 +49,6 @@ rule prep_io_data:
spatial_idx_name='segs_test',
time_idx_name='times_test',
catch_prop_file=None,
exclude_file=None,
train_start_date=config['train_start_date'],
train_end_date=config['train_end_date'],
val_start_date=config['val_start_date'],
1 change: 0 additions & 1 deletion workflow_examples/Snakefile_gwn.smk
@@ -61,7 +61,6 @@ rule prep_io_data:
y_vars_pretrain=config['y_vars_pretrain'],
y_vars_finetune=config['y_vars_finetune'],
catch_prop_file=None,
exclude_file=None,
train_start_date=config['train_start_date'],
train_end_date=config['train_end_date'],
val_start_date=config['val_start_date'],
1 change: 0 additions & 1 deletion workflow_examples/Snakefile_pretrain_LSTM.smk
@@ -54,7 +54,6 @@ rule prep_io_data:
spatial_idx_name='segs_test',
time_idx_name='times_test',
catch_prop_file=None,
exclude_file=None,
train_start_date=config['train_start_date'],
train_end_date=config['train_end_date'],
val_start_date=config['val_start_date'],
1 change: 0 additions & 1 deletion workflow_examples/Snakefile_rgcn.smk
@@ -63,7 +63,6 @@ rule prep_io_data:
y_vars_pretrain=config['y_vars_pretrain'],
y_vars_finetune=config['y_vars_finetune'],
catch_prop_file=None,
exclude_file=None,
train_start_date=config['train_start_date'],
train_end_date=config['train_end_date'],
val_start_date=config['val_start_date'],
1 change: 0 additions & 1 deletion workflow_examples/Snakefile_rgcn_hypertune.smk
@@ -67,7 +67,6 @@ rule prep_io_data:
y_vars_pretrain=config['y_vars_pretrain'],
y_vars_finetune=config['y_vars_finetune'],
catch_prop_file=None,
exclude_file=None,
train_start_date=config['train_start_date'],
train_end_date=config['train_end_date'],
val_start_date=config['val_start_date'],
1 change: 0 additions & 1 deletion workflow_examples/Snakefile_rgcn_pytorch.smk
@@ -61,7 +61,6 @@ rule prep_io_data:
y_vars_pretrain=config['y_vars_pretrain'],
y_vars_finetune=config['y_vars_finetune'],
catch_prop_file=None,
exclude_file=None,
train_start_date=config['train_start_date'],
train_end_date=config['train_end_date'],
val_start_date=config['val_start_date'],