Skip to content

Commit

Permalink
more additions for wrf
Browse files Browse the repository at this point in the history
  • Loading branch information
allibco committed Sep 20, 2024
1 parent 7cb79a2 commit 067b631
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 24 deletions.
11 changes: 7 additions & 4 deletions ldcpy/calcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1496,10 +1496,13 @@ def zscore_percent_significant(self) -> np.ndarray:
sorted_pvals = np.sort(pvals_array).flatten()
fdr_zscore = 0.01
p = np.argwhere(sorted_pvals <= fdr_zscore * np.arange(1, pvals.size + 1) / pvals.size)
pval_cutoff = sorted_pvals[p[len(p) - 1]]
if not (pval_cutoff.size == 0):
sig_locs = np.argwhere(pvals <= pval_cutoff)
percent_sig = 100 * np.size(sig_locs, 0) / pvals.size
if p.size > 0:
pval_cutoff = sorted_pvals[p[len(p) - 1]]
if not (pval_cutoff.size == 0):
sig_locs = np.argwhere(pvals <= pval_cutoff)
percent_sig = 100 * np.size(sig_locs, 0) / pvals.size
else:
percent_sig = 0
else:
percent_sig = 0
self._zscore_percent_significant = percent_sig
Expand Down
43 changes: 25 additions & 18 deletions ldcpy/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,9 @@ def get_calcs(self, da, data_type):
return raw_data

def get_plot_data(self, raw_data_1, raw_data_2=None):

time_dim = self._ds.cf.coordinates['time'][0]

if self._calc_type == 'diff':
plot_data = raw_data_1 - raw_data_2
plot_data.attrs = raw_data_1.attrs
Expand All @@ -198,7 +201,7 @@ def get_plot_data(self, raw_data_1, raw_data_2=None):
'odds_positive',
]:
plot_attrs = plot_data.attrs
plot_data = plot_data.groupby(self._group_by).mean(dim='time')
plot_data = plot_data.groupby(self._group_by).mean(dim=time_dim)
plot_data.attrs = plot_attrs

if self._transform == 'none':
Expand Down Expand Up @@ -374,6 +377,10 @@ def spatial_plot(self, da_sets, titles, data_type):
if np.isnan(cy_datas).all():
all_nan_flag = 1

if data_type == 'wrf':
if nan_inf_flag == 1 or all_nan_flag == 1:
axs[i].set_facecolor('#39ff14')

cyxr = xr.DataArray(data=cy_datas)

if self._cmax is not None:
Expand Down Expand Up @@ -418,9 +425,6 @@ def spatial_plot(self, da_sets, titles, data_type):
)

elif data_type == 'cam-fv':
# psets[i] = axs[i].imshow(
# img=flipud(no_inf_data_set), transform=ccrs.PlateCarree(), cmap=mymap
# )
psets[i] = axs[i].imshow(
img=flipud(no_inf_data_set), transform=ccrs.PlateCarree(), cmap=mymap
)
Expand Down Expand Up @@ -591,32 +595,31 @@ def time_series_plot(
self,
da_sets,
titles,
time_dim,
):
"""
time series plot
"""


time_dim = da_sets[0].cf.coordinates['time'][0]
data_type = da_sets[0].attrs['data_type']

group_string = 'time.year'
group_string = time_dim + '.year'
xlabel = 'date'
tick_interval = int(da_sets.size / da_sets.sets.size / 5) + 1
if da_sets.size / da_sets.sets.size == 1:
tick_interval = 1
if self._group_by == 'time.dayofyear':
if self._group_by == 'time.dayofyear' or self._group_by == 'Time.dayofyear':

group_string = 'dayofyear'
xlabel = 'Day of Year'
elif self._group_by == 'time.month':
elif self._group_by == 'time.month' or self._group_by == 'Time.month':
group_string = 'month'
xlabel = 'Month'
tick_interval = 1
elif self._group_by == 'time.year':
elif self._group_by == 'time.year' or self._group_by == 'Time.year':
group_string = 'year'
xlabel = 'Year'
elif self._group_by == 'time.day':
elif self._group_by == 'time.day' or self._group_by == 'time.day':
group_string = 'day'
xlabel = 'Day'

Expand Down Expand Up @@ -656,6 +659,8 @@ def time_series_plot(
}
)


print(group_string)
for i in range(da_sets.sets.size):
if self._group_by is not None:
plt.plot(
Expand Down Expand Up @@ -703,7 +708,7 @@ def time_series_plot(
mpl.pyplot.xticks(
np.arange(min(da_sets[group_string]), max(da_sets[group_string]) + 1, tick_interval)
)
if self._group_by == 'time.month':
if self._group_by == 'time.month' or self._group_by == 'Time.month':
int_labels = plt.xticks()[0]
month_labels = [
calendar.month_name[i] for i in int_labels if calendar.month_name[i] != ''
Expand Down Expand Up @@ -747,11 +752,13 @@ def get_calc_label(self, calc, data, data_type):
percent_sig = lm.Datasetcalcs(
(data), data_type, [time_dim], weighted=self._weighted
).get_single_calc('zscore_percent_significant')

if abs(zscore_cutoff[0]) > 0.01:
calc_name = f'{calc}: cutoff {zscore_cutoff[0]:.2f}, % sig: {percent_sig:.2f}'
else:
calc_name = f'{calc}: cutoff {zscore_cutoff[0]:.2e}, % sig: {percent_sig:.2f}'
if percent_sig == 0:
calc_name = f'{calc}'
else:
if abs(zscore_cutoff[0]) > 0.01:
calc_name = f'{calc}: cutoff {zscore_cutoff[0]:.2f}, % sig: {percent_sig:.2f}'
else:
calc_name = f'{calc}: cutoff {zscore_cutoff[0]:.2e}, % sig: {percent_sig:.2f}'

elif calc == 'mean' and self._plot_type == 'spatial' and self._calc_type == 'raw':
if self._weighted:
Expand Down Expand Up @@ -1217,7 +1224,7 @@ class in ldcpy.plot, for more information about the available calcs see ldcpy.Da
if plot_type == 'spatial':
mp.spatial_plot(plot_dataset, titles, ds.data_type)
elif plot_type == 'time_series':
mp.time_series_plot(plot_dataset, titles)
mp.time_series_plot(plot_dataset, titles, time_dim)
elif plot_type == 'histogram':
mp.hist_plot(plot_dataset, titles)
elif plot_type == 'periodogram':
Expand Down
5 changes: 3 additions & 2 deletions ldcpy/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,9 +73,10 @@ def collect_datasets(data_type, varnames, list_of_ds, labels, coords_ds=None, **
indx = np.where(latlon_found > 1)[0]
assert len(indx) == len(list_of_ds), 'ERROR: WRF datasets must contain XLAT and XLONG'
else: #has a coords ds
#copy corrds to the datasets
ds_notime = coords_ds.drop_dims("Time")
#copy coords to EACH of the datasets
for i, myds in enumerate (list_of_ds):
ds_new = myds.assign_coords(coords_ds.coords)
ds_new = myds.assign_coords(ds_notime.coords)
list_of_ds[i] = ds_new.copy(deep=True)

# weights?
Expand Down

0 comments on commit 067b631

Please sign in to comment.