diff --git a/xmitgcm/llcreader/llcmodel.py b/xmitgcm/llcreader/llcmodel.py
index edce045..a33dbaf 100644
--- a/xmitgcm/llcreader/llcmodel.py
+++ b/xmitgcm/llcreader/llcmodel.py
@@ -37,7 +37,7 @@ def _get_grid_metadata():
 
     return grid_metadata
 
-def _get_var_metadata():
+def _get_var_metadata(extra_variables=None):
     # The LLC run data comes with zero metadata. So we import metadata from
     # the xmitgcm package.
     from ..variables import state_variables, package_state_variables
@@ -50,6 +50,8 @@ def _get_var_metadata():
     var_metadata = state_variables.copy()
     var_metadata.update(package_state_variables)
     var_metadata.update(available_diags)
+    if extra_variables is not None:
+        var_metadata.update(extra_variables)
 
     # even the file names from the LLC data differ from standard MITgcm output
     aliases = {'Eta': 'ETAN', 'PhiBot': 'PHIBOT', 'Salt': 'SALT',
@@ -62,18 +64,11 @@ def _get_var_metadata():
 
     return var_metadata
 
-_VAR_METADATA = _get_var_metadata()
-
-def _is_vgrid(vname):
-    # check for 1d, vertical grid variables
-    dims = _VAR_METADATA[vname]['dims']
-    return len(dims)==1 and dims[0][0]=='k'
-
-def _get_variable_point(vname, mask_override):
+def _get_variable_point(vname, dims, mask_override):
     # fix for https://github.com/MITgcm/xmitgcm/issues/191
     if vname in mask_override:
         return mask_override[vname]
-    dims = _VAR_METADATA[vname]['dims']
     if 'i' in dims and 'j' in dims:
         point = 'c'
     elif 'i_g' in dims and 'j' in dims:
@@ -86,30 +81,6 @@ def _get_variable_point(vname, mask_override):
         raise ValueError("Variable `%s` is not a horizontal variable." % vname)
     return point
 
-def _get_scalars_and_vectors(varnames, type):
-
-    for vname in varnames:
-        if vname not in _VAR_METADATA:
-            raise ValueError("Varname `%s` not found in metadata." % vname)
-
-    if type != 'latlon':
-        return varnames, []
-
-    scalars = []
-    vector_pairs = []
-    for vname in varnames:
-        meta = _VAR_METADATA[vname]
-        try:
-            mate = meta['attrs']['mate']
-            if mate not in varnames:
-                raise ValueError("Vector pairs are required to create "
-                                 "latlon type datasets. Varname `%s` is "
-                                 "missing its vector mate `%s`"
-                                 % vname, mate)
-            vector_pairs.append((vname, mate))
-            varnames.remove(mate)
-        except KeyError:
-            scalars.append(vname)
 
 
 def _decompress(data, mask, dtype):
     data_blank = np.full_like(mask, np.nan, dtype=dtype)
@@ -450,7 +421,7 @@ def _chunks(l, n):
 
 
 def _get_facet_chunk(store, varname, iternum, nfacet, klevels, nx, nz, nfaces,
-                     dtype, mask_override, domain, pad_before, pad_after):
+                     dtype, mask_override, domain, pad_before, pad_after, dims):
 
     fs, path = store.get_fs_and_full_path(varname, iternum)
 
@@ -468,7 +439,7 @@ def _get_facet_chunk(store, varname, iternum, nfacet, klevels, nx, nz, nfaces,
     if (store.shrunk and iternum is not None) or \
        (store.shrunk_grid and iternum is None):
         # the store tells us whether we need a mask or not
-        point = _get_variable_point(varname, mask_override)
+        point = _get_variable_point(varname, dims, mask_override)
         mykey = nx if domain == 'global' else f'{domain}_{nx}'
         index = all_index_data[mykey][point]
         zgroup = store.open_mask_group()
@@ -594,6 +565,7 @@ class BaseLLCModel:
     varnames = []
     grid_varnames = []
     mask_override = {}
+    var_metadata = None
    domain = 'global'
     pad_before = [0]*_nfacets
     pad_after = [0]*_nfacets
@@ -632,6 +604,37 @@ def _dtype(self,varname=None):
         elif isinstance(self.dtype,dict):
             return np.dtype(self.dtype[varname])
 
+    def _is_vgrid(self, vname):
+        # check for 1d, vertical grid variables
+        dims = self.var_metadata[vname]['dims']
+        return len(dims)==1 and dims[0][0]=='k'
+
+
+    def _get_scalars_and_vectors(self, varnames, type):
+
+        for vname in varnames:
+            if vname not in self.var_metadata:
+                raise ValueError("Varname `%s` not found in metadata." % vname)
+
+        if type != 'latlon':
+            return varnames, []
+
+        scalars = []
+        vector_pairs = []
+        for vname in varnames:
+            meta = self.var_metadata[vname]
+            try:
+                mate = meta['attrs']['mate']
+                if mate not in varnames:
+                    raise ValueError("Vector pairs are required to create "
+                                     "latlon type datasets. Varname `%s` is "
+                                     "missing its vector mate `%s`"
+                                     % (vname, mate))
+                vector_pairs.append((vname, mate))
+                varnames.remove(mate)
+            except KeyError:
+                scalars.append(vname)
+
     def _get_kp1_levels(self,k_levels):
         # determine kp1 levels
         # get borders to all k (center) levels
@@ -698,10 +701,11 @@ def _key_and_task(n_k, these_klevels, n_iter=None, iternum=None):
                 key = name, n_k, 0, 0, 0
             else:
                 key = name, n_iter, n_k, 0, 0, 0
+            dims = self.var_metadata[varname]['dims']
             task = (_get_facet_chunk, self.store, varname, iternum,
                     nfacet, these_klevels, self.nx, self.nz, self.nface,
                     dtype, self.mask_override, self.domain,
-                    self.pad_before, self.pad_after)
+                    self.pad_before, self.pad_after, dims)
             return key, task
 
         if iters is not None:
@@ -729,7 +733,7 @@ def _dask_array_vgrid(self, varname, klevels, k_chunksize):
         name = '-'.join([varname, token])
 
         dtype = self._dtype(varname)
-        nz = self.nz if _VAR_METADATA[varname]['dims'] != ['k_p1'] else self.nz+1
+        nz = self.nz if self.var_metadata[varname]['dims'] != ['k_p1'] else self.nz+1
         task = (_get_1d_chunk, self.store, varname,
                 list(klevels), nz, dtype)
 
@@ -740,12 +744,12 @@ def _dask_array_vgrid(self, varname, klevels, k_chunksize):
 
     def _get_facet_data(self, varname, iters, klevels, k_chunksize):
         # needs facets to be outer index of nested lists
-        dims = _VAR_METADATA[varname]['dims']
+        dims = self.var_metadata[varname]['dims']
 
         if len(dims)==2:
             klevels = [0,]
 
-        if _is_vgrid(varname):
+        if self._is_vgrid(varname):
             data_facets = self._dask_array_vgrid(varname,klevels,k_chunksize)
         else:
             data_facets = [self._dask_array(nfacet, varname, iters, klevels, k_chunksize)
@@ -776,18 +780,19 @@ def _check_iters(self, iters):
             if not set(iters) <= set(self.iters):
                 msg = "Some requested iterations may not exist, you may need to change 'iters'"
                 warnings.warn(msg, RuntimeWarning)
-    
+
         elif self.iter_start is not None and self.iter_step is not None:
             for iter in iters:
                 if (iter - self.iter_start) % self.iter_step:
                     msg = "Some requested iterations may not exist, you may need to change 'iters'"
                     warnings.warn(msg, RuntimeWarning)
                     break
-    
+
     def get_dataset(self, varnames=None, iter_start=None, iter_stop=None,
                     iter_step=None, iters=None, k_levels=None, k_chunksize=1,
-                    type='faces', read_grid=True, grid_vars_to_coords=True):
+                    type='faces', read_grid=True, grid_vars_to_coords=True,
+                    extra_variables=None):
         """
         Create an xarray Dataset object for this model.
 
@@ -817,6 +822,22 @@ def get_dataset(self, varnames=None, iter_start=None, iter_stop=None,
             Whether to read the grid info
         grid_vars_to_coords : bool, optional
             Whether to promote grid variables to coordinate status
+        extra_variables : dict, optional
+            Allows passing variables not listed in variables.py or in
+            available_diagnostics.log. extra_variables must be a dict with
+            the variable names as keys; each corresponding value is a dict
+            with the keys dims and attrs.
+
+            Syntax:
+            extra_variables = dict(varname=dict(dims=list_of_dims, attrs=dict(optional_attrs)))
+            where optional_attrs can contain standard_name, long_name, units as keys
+
+            Example:
+            extra_variables = dict(
+                ADJtheta = dict(dims=['k','j','i'], attrs=dict(
+                    standard_name='Sensitivity_to_theta',
+                    long_name='Sensitivity of cost function to theta', units='[J]/degC'))
+            )
 
         Returns
         -------
 
@@ -829,6 +850,7 @@ def _if_not_none(a, b):
             else:
                 return a
 
+        self.var_metadata = _get_var_metadata(extra_variables=extra_variables)
         user_iter_params = [iter_start, iter_stop, iter_step]
         attribute_iter_params = [self.iter_start, self.iter_stop, self.iter_step]
 
@@ -838,7 +860,7 @@ def _if_not_none(a, b):
             if iters is not None:
                 raise ValueError("Only `iters` or the parameters `iter_start`, `iters_stop`, "
                                  "and `iter_step` can be provided. Both were provided")
-    
+
             # Otherwise we can override any missing values
             iter_start = _if_not_none(iter_start, self.iter_start)
             iter_stop = _if_not_none(iter_stop, self.iter_stop)
@@ -849,12 +871,12 @@ def _if_not_none(a, b):
                                  "and `iter_step` must be defined either by the "
                                  "model class or as argument. Instead got %r "
                                  % iter_params)
-    
+
         # Otherwise try loading from the user set iters
         elif iters is not None:
             pass
 
-        # Now have a go at using the attribute derived iteration parameters 
+        # Now have a go at using the attribute derived iteration parameters
         elif all([a is not None for a in attribute_iter_params]):
             iter_params = attribute_iter_params
 
@@ -867,7 +889,7 @@ def _if_not_none(a, b):
             raise ValueError("The parameters `iter_start`, `iter_stop`, "
                              "and `iter_step`, or `iters` must be defined either by the "
                              "model class or as argument")
-    
+
         # Check the iter_start and iter_step
         if iters is None:
             self._check_iter_start(iter_params[0])
@@ -906,7 +928,7 @@ def _if_not_none(a, b):
         # do separately for vertical coords on kp1_levels
         grid_facets = {}
         for vname in grid_varnames:
-            my_k_levels = k_levels if _VAR_METADATA[vname]['dims'] !=['k_p1'] else kp1_levels
+            my_k_levels = k_levels if self.var_metadata[vname]['dims'] !=['k_p1'] else kp1_levels
             grid_facets[vname] = self._get_facet_data(vname, None, my_k_levels, k_chunksize)
 
         # transform it into faces or latlon
@@ -914,22 +936,22 @@ def _if_not_none(a, b):
                              'latlon': _all_facets_to_latlon}
 
         transformer = data_transformers[type]
-        data = transformer(data_facets, _VAR_METADATA, self.nface)
+        data = transformer(data_facets, self.var_metadata, self.nface)
 
         # separate horizontal and vertical grid variables
         hgrid_facets = {key: grid_facets[key]
-                        for key in grid_varnames if not _is_vgrid(key)}
+                        for key in grid_varnames if not self._is_vgrid(key)}
         vgrid_facets = {key: grid_facets[key]
-                        for key in grid_varnames if _is_vgrid(key)}
+                        for key in grid_varnames if self._is_vgrid(key)}
 
         # do not transform vertical grid variables
-        data.update(transformer(hgrid_facets, _VAR_METADATA, self.nface))
+        data.update(transformer(hgrid_facets, self.var_metadata, self.nface))
         data.update(vgrid_facets)
 
         variables = {}
         gridlist = ['Zl','Zu'] if read_grid else []
         for vname in varnames+grid_varnames:
-            meta = _VAR_METADATA[vname]
+            meta = self.var_metadata[vname]
             dims = meta['dims']
             if type=='faces':
                 dims = _add_face_to_dims(dims)
@@ -948,9 +970,9 @@ def _if_not_none(a, b):
         if read_grid and 'RF' in grid_varnames:
             ki = np.array([list(kp1_levels).index(x) for x in k_levels])
             for zv,sl in zip(['Zl','Zu'],[ki,ki+1]):
-                variables[zv] = xr.Variable(_VAR_METADATA[zv]['dims'],
+                variables[zv] = xr.Variable(self.var_metadata[zv]['dims'],
                                             data['RF'][sl],
-                                            _VAR_METADATA[zv]['attrs'])
+                                            self.var_metadata[zv]['attrs'])
 
         ds = ds.update(variables)
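For context, this is how the new `extra_variables` plumbing above is meant to be driven end to end. A minimal sketch, assuming an adjoint output `ADJtheta` (the name follows the docstring example; the ECCO portal model class is just one of the bundled llcreader models, and such a file does not necessarily exist on the portal):

    import xmitgcm.llcreader as llcreader

    # describe a variable that appears in neither variables.py nor
    # available_diagnostics.log (hypothetical adjoint sensitivity output)
    extra_variables = dict(
        ADJtheta=dict(dims=['k', 'j', 'i'],
                      attrs=dict(standard_name='Sensitivity_to_theta',
                                 long_name='Sensitivity of cost function to theta',
                                 units='[J]/degC'))
    )

    model = llcreader.ECCOPortalLLC2160Model()
    ds = model.get_dataset(varnames=['ADJtheta'],
                           extra_variables=extra_variables)

`get_dataset` forwards the dict to `_get_var_metadata`, which merges it over the packaged metadata, so an extra variable behaves like any known diagnostic from that point on.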
diff --git a/xmitgcm/mds_store.py b/xmitgcm/mds_store.py
index ea3730f..665f9b7 100644
--- a/xmitgcm/mds_store.py
+++ b/xmitgcm/mds_store.py
@@ -59,7 +59,8 @@ def open_mdsdataset(data_dir, grid_dir=None,
                     ignore_unknown_vars=False, default_dtype=None,
                     nx=None, ny=None, nz=None, llc_method="smallchunks",
                     extra_metadata=None,
-                    extra_variables=None):
+                    extra_variables=None,
+                    custom_grid_variables=None):
     """Open MITgcm-style mds (.data / .meta) file output as xarray datset.
 
     Parameters
     ----------
@@ -148,6 +149,8 @@ def open_mdsdataset(data_dir, grid_dir=None,
             standard_name='Sensitivity_to_theta',
             long_name='Sensitivity of cost function to theta', units='[J]/degC'))
         )
+    custom_grid_variables : dict, optional
+        Similar to extra_variables, but for grid files, whose names carry no time stamp.
 
 
     Returns
@@ -235,7 +238,8 @@ def open_mdsdataset(data_dir, grid_dir=None,
                       default_dtype=default_dtype,
                       nx=nx, ny=ny, nz=nz, llc_method=llc_method,
                       levels=levels, extra_metadata=extra_metadata,
-                      extra_variables=extra_variables)
+                      extra_variables=extra_variables,
+                      custom_grid_variables=custom_grid_variables)
         datasets = [open_mdsdataset(
                 data_dir, iters=iternum, read_grid=False, **kwargs)
             for iternum in iters]
@@ -291,7 +295,8 @@ def open_mdsdataset(data_dir, grid_dir=None,
                           default_dtype=default_dtype,
                           nx=nx, ny=ny, nz=nz, llc_method=llc_method,
                           levels=levels, extra_metadata=extra_metadata,
-                          extra_variables=extra_variables)
+                          extra_variables=extra_variables,
+                          custom_grid_variables=custom_grid_variables)
         ds = xr.Dataset.load_store(store)
 
     if swap_dims:
@@ -376,7 +381,8 @@ def __init__(self, data_dir, grid_dir=None,
                  default_dtype=np.dtype('f4'),
                  nx=None, ny=None, nz=None, llc_method="smallchunks",
                  levels=None, extra_metadata=None,
-                 extra_variables=None):
+                 extra_variables=None,
+                 custom_grid_variables=None):
         """
         This is not a user-facing class. See open_mdsdataset for argument
         documentation. The only ones which are distinct are.
 
@@ -401,6 +407,7 @@ def __init__(self, data_dir, grid_dir=None,
         self.data_dir = data_dir
         self.grid_dir = grid_dir if (grid_dir is not None) else data_dir
         self.extra_variables = extra_variables
+        self.custom_grid_variables = custom_grid_variables
         self._ignore_unknown_vars = ignore_unknown_vars
 
         # The endianness of the files
@@ -573,7 +580,8 @@ def __init__(self, data_dir, grid_dir=None,
         # build lookup tables for variable metadata
         self._all_grid_variables = _get_all_grid_variables(self.geometry,
                                                            self.grid_dir,
-                                                           self.layers)
+                                                           self.layers,
+                                                           self.custom_grid_variables)
         self._all_data_variables = _get_all_data_variables(self.data_dir,
                                                            self.grid_dir,
                                                            self.layers,
@@ -831,7 +839,7 @@ def _guess_layers(data_dir):
     return all_layers
 
 
-def _get_all_grid_variables(geometry, grid_dir=None, layers={}):
+def _get_all_grid_variables(geometry, grid_dir=None, layers={}, custom_grid_variables=None):
     """"Put all the relevant grid metadata into one big dictionary."""
     possible_hcoords = {'cartesian': horizontal_coordinates_cartesian,
                         'llc': horizontal_coordinates_llc,
@@ -841,7 +849,7 @@ def _get_all_grid_variables(geometry, grid_dir=None, layers={}):
     hcoords = possible_hcoords[geometry]
 
     # look for extra variables, if they exist in grid_dir
-    extravars = _get_extra_grid_variables(grid_dir) if grid_dir is not None else {}
+    extravars = _get_extra_grid_variables(grid_dir, custom_grid_variables=custom_grid_variables) if grid_dir is not None else {}
 
     allvars = [hcoords, vertical_coordinates, horizontal_grid_variables,
                vertical_grid_variables, volume_grid_variables, mask_variables,
@@ -856,17 +864,20 @@ def _get_all_grid_variables(geometry, grid_dir=None, layers={}):
 
     return metadata
 
 
-def _get_extra_grid_variables(grid_dir):
+def _get_extra_grid_variables(grid_dir, custom_grid_variables):
     """Scan a directory and return all file prefixes for extra grid files.
     Then return the variable information for each of these"""
     extra_grid = {}
 
+    if custom_grid_variables is not None:
+        extra_grid_variables.update(custom_grid_variables)
+
     fnames = dict([[val['filename'],key] for key,val in extra_grid_variables.items()
                    if 'filename' in val])
     all_datafiles = listdir_endswith(grid_dir, '.data')
     for f in all_datafiles:
         prefix = os.path.split(f[:-5])[-1]
-        # Only consider what we find that matches extra_grid_vars
+        # Only consider what we find that matches extra/custom_grid_vars
         if prefix in extra_grid_variables:
             extra_grid[prefix] = extra_grid_variables[prefix]
         elif prefix in fnames:
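Likewise, a minimal sketch of the new `custom_grid_variables` argument to `open_mdsdataset`; the prefix `CustomFac`, its dims, and the run path are assumptions, standing in for any grid file written once, with no iteration number in its name:

    import xmitgcm

    # hypothetical grid file pair CustomFac.data / CustomFac.meta in the run dir
    custom_grid_variables = dict(
        CustomFac=dict(dims=['k', 'j', 'i'], attrs={})
    )

    ds = xmitgcm.open_mdsdataset('path/to/run', iters=None, read_grid=True,
                                 prefix=list(custom_grid_variables),
                                 custom_grid_variables=custom_grid_variables)

Because `_get_extra_grid_variables` merges the dict into `extra_grid_variables` before scanning the grid directory, matching prefixes are picked up by the same code path as the built-in extra grid files and land in the dataset as coordinates, as the test below verifies.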
diff --git a/xmitgcm/test/test_mds_store.py b/xmitgcm/test/test_mds_store.py
index dc7cbe2..25149cc 100644
--- a/xmitgcm/test/test_mds_store.py
+++ b/xmitgcm/test/test_mds_store.py
@@ -608,6 +608,41 @@ def test_extra_variables(all_mds_datadirs):
         mate = ds[var].attrs['mate']
         assert ds[mate].attrs['mate'] == var
 
+def test_custom_grid_variables(all_mds_datadirs):
+    """Test that open_mdsdataset reads custom grid variables (i.e. no time stamp) correctly"""
+    dirname, expected = all_mds_datadirs
+
+    custom_grid_variables = {
+        "iamgridC": {
+            "dims": ["k", "j", "i"], "attrs": {},
+        },
+        "iamgridW": {
+            "dims": ["k", "j", "i_g"], "attrs": {},
+        },
+        "iamgridS": {
+            "dims": ["k", "j_g", "i"], "attrs": {},
+        },
+    }
+
+    # copy hFac to our new grid variable ...
+    for suffix in ["C", "W", "S"]:
+        for ext in [".meta", ".data"]:
+            fname_in = os.path.join(dirname, f"hFac{suffix}{ext}")
+            fname_out = os.path.join(dirname, f"iamgrid{suffix}{ext}")
+            copyfile(fname_in, fname_out)
+
+    ds = xmitgcm.open_mdsdataset(
+            dirname,
+            read_grid=True,
+            iters=None,
+            geometry=expected["geometry"],
+            prefix=list(custom_grid_variables.keys()),
+            custom_grid_variables=custom_grid_variables)
+
+    for var in custom_grid_variables.keys():
+        assert var in ds
+        assert var in ds.coords
+
 def test_mask_values(all_mds_datadirs):
     """Test that open_mdsdataset generates binary masks with correct values"""
 
diff --git a/xmitgcm/test/test_xmitgcm_common.py b/xmitgcm/test/test_xmitgcm_common.py
index 5c7cbd8..3d44e8e 100644
--- a/xmitgcm/test/test_xmitgcm_common.py
+++ b/xmitgcm/test/test_xmitgcm_common.py
@@ -278,7 +278,7 @@ def file_md5_checksum(fname):
 
 # find the tar archive in the test directory
 # http://stackoverflow.com/questions/29627341/pytest-where-to-store-expected-data
-@pytest.fixture(scope='module', params=_experiments.keys())
+@pytest.fixture(scope='function', params=_experiments.keys())
 def all_mds_datadirs(tmpdir_factory, request):
     return setup_mds_dir(tmpdir_factory, request, _experiments)