Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Custom Grid Vars & Extra Variables for llcreader #308

Draft
wants to merge 7 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
130 changes: 76 additions & 54 deletions xmitgcm/llcreader/llcmodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def _get_grid_metadata():

return grid_metadata

def _get_var_metadata():
def _get_var_metadata(extra_variables=None):
# The LLC run data comes with zero metadata. So we import metadata from
# the xmitgcm package.
from ..variables import state_variables, package_state_variables
Expand All @@ -50,6 +50,8 @@ def _get_var_metadata():
var_metadata = state_variables.copy()
var_metadata.update(package_state_variables)
var_metadata.update(available_diags)
if extra_variables is not None:
var_metadata.update(extra_variables)

# even the file names from the LLC data differ from standard MITgcm output
aliases = {'Eta': 'ETAN', 'PhiBot': 'PHIBOT', 'Salt': 'SALT',
Expand All @@ -62,18 +64,11 @@ def _get_var_metadata():

return var_metadata

# Variable metadata table, built once at module level by calling
# _get_var_metadata() with its default arguments.
_VAR_METADATA = _get_var_metadata()

def _is_vgrid(vname):
    """Return True if ``vname`` is a 1D vertical grid variable.

    The dimensions are looked up in the module-level ``_VAR_METADATA``
    table. A variable counts as a vertical grid variable when it has
    exactly one dimension and the name of that dimension starts with
    ``'k'`` (for example ``'k'`` or ``'k_p1'``).

    Parameters
    ----------
    vname : str
        Name of the variable to look up.

    Returns
    -------
    bool

    Raises
    ------
    KeyError
        If ``vname`` is missing from ``_VAR_METADATA`` or its entry has no
        ``'dims'`` key.
    """
    dims = _VAR_METADATA[vname]['dims']
    if len(dims) != 1:
        return False
    # startswith() rather than dims[0][0]: an empty dimension name then
    # yields False instead of raising IndexError.
    return dims[0].startswith('k')

def _get_variable_point(vname, mask_override):
def _get_variable_point(vname, dims, mask_override):
# fix for https://github.com/MITgcm/xmitgcm/issues/191
if vname in mask_override:
return mask_override[vname]
dims = _VAR_METADATA[vname]['dims']
if 'i' in dims and 'j' in dims:
point = 'c'
elif 'i_g' in dims and 'j' in dims:
Expand All @@ -86,30 +81,6 @@ def _get_variable_point(vname, mask_override):
raise ValueError("Variable `%s` is not a horizontal variable." % vname)
return point

def _get_scalars_and_vectors(varnames, type):
    """Split ``varnames`` into scalar variables and vector pairs.

    Parameters
    ----------
    varnames : list of str
        Variable names. Every name must be a key of the module-level
        ``_VAR_METADATA`` table.
    type : str
        Dataset type. For anything other than ``'latlon'`` no pairing is
        attempted and ``varnames`` is returned as is.

    Returns
    -------
    scalars : list of str
        Names without a ``'mate'`` attribute (or all of ``varnames`` when
        ``type`` is not ``'latlon'``).
    vector_pairs : list of tuple
        ``(vname, mate)`` tuples, one per vector pair, in the order the
        first component appears in ``varnames``.

    Raises
    ------
    ValueError
        If a name is not found in the metadata, or if a vector component
        is requested without its mate for a ``'latlon'`` dataset.
    """
    for vname in varnames:
        if vname not in _VAR_METADATA:
            raise ValueError("Varname `%s` not found in metadata." % vname)

    if type != 'latlon':
        return varnames, []

    scalars = []
    vector_pairs = []
    # Mates already consumed as the second half of a pair. Tracking them
    # here avoids removing items from ``varnames`` while iterating over it
    # (and avoids mutating the caller's list).
    paired = set()
    for vname in varnames:
        if vname in paired:
            continue
        try:
            mate = _VAR_METADATA[vname]['attrs']['mate']
        except KeyError:
            # no mate attribute: plain scalar variable
            scalars.append(vname)
            continue
        if mate not in varnames:
            # The format arguments must be a tuple; ``% vname, mate``
            # would raise TypeError instead of this ValueError.
            raise ValueError("Vector pairs are required to create "
                             "latlon type datasets. Varname `%s` is "
                             "missing its vector mate `%s`"
                             % (vname, mate))
        vector_pairs.append((vname, mate))
        paired.add(mate)

    return scalars, vector_pairs

def _decompress(data, mask, dtype):
data_blank = np.full_like(mask, np.nan, dtype=dtype)
Expand Down Expand Up @@ -450,7 +421,7 @@ def _chunks(l, n):


def _get_facet_chunk(store, varname, iternum, nfacet, klevels, nx, nz, nfaces,
dtype, mask_override, domain, pad_before, pad_after):
dtype, mask_override, domain, pad_before, pad_after, dims):

fs, path = store.get_fs_and_full_path(varname, iternum)

Expand All @@ -468,7 +439,7 @@ def _get_facet_chunk(store, varname, iternum, nfacet, klevels, nx, nz, nfaces,
if (store.shrunk and iternum is not None) or \
(store.shrunk_grid and iternum is None):
# the store tells us whether we need a mask or not
point = _get_variable_point(varname, mask_override)
point = _get_variable_point(varname, dims, mask_override)
mykey = nx if domain == 'global' else f'{domain}_{nx}'
index = all_index_data[mykey][point]
zgroup = store.open_mask_group()
Expand Down Expand Up @@ -594,6 +565,7 @@ class BaseLLCModel:
varnames = []
grid_varnames = []
mask_override = {}
var_metadata = None
domain = 'global'
pad_before = [0]*_nfacets
pad_after = [0]*_nfacets
Expand Down Expand Up @@ -632,6 +604,37 @@ def _dtype(self,varname=None):
elif isinstance(self.dtype,dict):
return np.dtype(self.dtype[varname])

def _is_vgrid(self, vname):
# check for 1d, vertical grid variables
dims = self.var_metadata[vname]['dims']
return len(dims)==1 and dims[0][0]=='k'


def _get_scalars_and_vectors(self, varnames, type):

for vname in varnames:
if vname not in self.var_metadata:
raise ValueError("Varname `%s` not found in metadata." % vname)

if type != 'latlon':
return varnames, []

scalars = []
vector_pairs = []
for vname in varnames:
meta = self.var_metadata[vname]
try:
mate = meta['attrs']['mate']
if mate not in varnames:
raise ValueError("Vector pairs are required to create "
"latlon type datasets. Varname `%s` is "
"missing its vector mate `%s`"
% vname, mate)
vector_pairs.append((vname, mate))
varnames.remove(mate)
except KeyError:
scalars.append(vname)

def _get_kp1_levels(self,k_levels):
# determine kp1 levels
# get borders to all k (center) levels
Expand Down Expand Up @@ -698,10 +701,11 @@ def _key_and_task(n_k, these_klevels, n_iter=None, iternum=None):
key = name, n_k, 0, 0, 0
else:
key = name, n_iter, n_k, 0, 0, 0
dims = self.var_metadata[varname]['dims']
task = (_get_facet_chunk, self.store, varname, iternum,
nfacet, these_klevels, self.nx, self.nz, self.nface,
dtype, self.mask_override, self.domain,
self.pad_before, self.pad_after)
self.pad_before, self.pad_after, dims)
return key, task

if iters is not None:
Expand Down Expand Up @@ -729,7 +733,7 @@ def _dask_array_vgrid(self, varname, klevels, k_chunksize):
name = '-'.join([varname, token])
dtype = self._dtype(varname)

nz = self.nz if _VAR_METADATA[varname]['dims'] != ['k_p1'] else self.nz+1
nz = self.nz if self.var_metadata[varname]['dims'] != ['k_p1'] else self.nz+1
task = (_get_1d_chunk, self.store, varname,
list(klevels), nz, dtype)

Expand All @@ -740,12 +744,12 @@ def _dask_array_vgrid(self, varname, klevels, k_chunksize):

def _get_facet_data(self, varname, iters, klevels, k_chunksize):
# needs facets to be outer index of nested lists
dims = _VAR_METADATA[varname]['dims']
dims = self.var_metadata[varname]['dims']

if len(dims)==2:
klevels = [0,]

if _is_vgrid(varname):
if self._is_vgrid(varname):
data_facets = self._dask_array_vgrid(varname,klevels,k_chunksize)
else:
data_facets = [self._dask_array(nfacet, varname, iters, klevels, k_chunksize)
Expand Down Expand Up @@ -776,18 +780,19 @@ def _check_iters(self, iters):
if not set(iters) <= set(self.iters):
msg = "Some requested iterations may not exist, you may need to change 'iters'"
warnings.warn(msg, RuntimeWarning)

elif self.iter_start is not None and self.iter_step is not None:
for iter in iters:
if (iter - self.iter_start) % self.iter_step:
msg = "Some requested iterations may not exist, you may need to change 'iters'"
warnings.warn(msg, RuntimeWarning)
break


def get_dataset(self, varnames=None, iter_start=None, iter_stop=None,
iter_step=None, iters=None, k_levels=None, k_chunksize=1,
type='faces', read_grid=True, grid_vars_to_coords=True):
type='faces', read_grid=True, grid_vars_to_coords=True,
extra_variables=None):
"""
Create an xarray Dataset object for this model.

Expand Down Expand Up @@ -817,6 +822,22 @@ def get_dataset(self, varnames=None, iter_start=None, iter_stop=None,
Whether to read the grid info
grid_vars_to_coords : bool, optional
Whether to promote grid variables to coordinate status
extra_variables : dict, optional
Metadata for variables that are not listed in variables.py
or in available_diagnostics.log.
extra_variables must be a dict with the variable names as keys and,
as the corresponding values, dicts with the keys dims and attrs.

Syntax:
extra_variables = dict(varname = dict(dims=list_of_dims, attrs=dict(optional_attrs)))
where optional_attrs can contain standard_name, long_name, units as keys

Example:
extra_variables = dict(
ADJtheta = dict(dims=['k','j','i'], attrs=dict(
standard_name='Sensitivity_to_theta',
long_name='Sensitivity of cost function to theta', units='[J]/degC'))
)

Returns
-------
Expand All @@ -829,6 +850,7 @@ def _if_not_none(a, b):
else:
return a

self.var_metadata = _get_var_metadata(extra_variables=extra_variables)
user_iter_params = [iter_start, iter_stop, iter_step]
attribute_iter_params = [self.iter_start, self.iter_stop, self.iter_step]

Expand All @@ -838,7 +860,7 @@ def _if_not_none(a, b):
if iters is not None:
raise ValueError("Only `iters` or the parameters `iter_start`, `iters_stop`, "
"and `iter_step` can be provided. Both were provided")

# Otherwise we can override any missing values
iter_start = _if_not_none(iter_start, self.iter_start)
iter_stop = _if_not_none(iter_stop, self.iter_stop)
Expand All @@ -849,12 +871,12 @@ def _if_not_none(a, b):
"and `iter_step` must be defined either by the "
"model class or as argument. Instead got %r "
% iter_params)

# Otherwise try loading from the user set iters
elif iters is not None:
pass

# Now have a go at using the attribute derived iteration parameters
# Now have a go at using the attribute derived iteration parameters
elif all([a is not None for a in attribute_iter_params]):
iter_params = attribute_iter_params

Expand All @@ -867,7 +889,7 @@ def _if_not_none(a, b):
raise ValueError("The parameters `iter_start`, `iter_stop`, "
"and `iter_step`, or `iters` must be defined either by the "
"model class or as argument")

# Check the iter_start and iter_step
if iters is None:
self._check_iter_start(iter_params[0])
Expand Down Expand Up @@ -906,30 +928,30 @@ def _if_not_none(a, b):
# do separately for vertical coords on kp1_levels
grid_facets = {}
for vname in grid_varnames:
my_k_levels = k_levels if _VAR_METADATA[vname]['dims'] !=['k_p1'] else kp1_levels
my_k_levels = k_levels if self.var_metadata[vname]['dims'] !=['k_p1'] else kp1_levels
grid_facets[vname] = self._get_facet_data(vname, None, my_k_levels, k_chunksize)

# transform it into faces or latlon
data_transformers = {'faces': _all_facets_to_faces,
'latlon': _all_facets_to_latlon}

transformer = data_transformers[type]
data = transformer(data_facets, _VAR_METADATA, self.nface)
data = transformer(data_facets, self.var_metadata, self.nface)

# separate horizontal and vertical grid variables
hgrid_facets = {key: grid_facets[key]
for key in grid_varnames if not _is_vgrid(key)}
for key in grid_varnames if not self._is_vgrid(key)}
vgrid_facets = {key: grid_facets[key]
for key in grid_varnames if _is_vgrid(key)}
for key in grid_varnames if self._is_vgrid(key)}

# do not transform vertical grid variables
data.update(transformer(hgrid_facets, _VAR_METADATA, self.nface))
data.update(transformer(hgrid_facets, self.var_metadata, self.nface))
data.update(vgrid_facets)

variables = {}
gridlist = ['Zl','Zu'] if read_grid else []
for vname in varnames+grid_varnames:
meta = _VAR_METADATA[vname]
meta = self.var_metadata[vname]
dims = meta['dims']
if type=='faces':
dims = _add_face_to_dims(dims)
Expand All @@ -948,9 +970,9 @@ def _if_not_none(a, b):
if read_grid and 'RF' in grid_varnames:
ki = np.array([list(kp1_levels).index(x) for x in k_levels])
for zv,sl in zip(['Zl','Zu'],[ki,ki+1]):
variables[zv] = xr.Variable(_VAR_METADATA[zv]['dims'],
variables[zv] = xr.Variable(self.var_metadata[zv]['dims'],
data['RF'][sl],
_VAR_METADATA[zv]['attrs'])
self.var_metadata[zv]['attrs'])

ds = ds.update(variables)

Expand Down
Loading
Loading