diff --git a/xmitgcm/llcreader/llcmodel.py b/xmitgcm/llcreader/llcmodel.py index 81f6349..a33dbaf 100644 --- a/xmitgcm/llcreader/llcmodel.py +++ b/xmitgcm/llcreader/llcmodel.py @@ -65,6 +65,22 @@ def _get_var_metadata(extra_variables=None): return var_metadata +def _get_variable_point(vname, dims, mask_override): + # fix for https://github.com/MITgcm/xmitgcm/issues/191 + if vname in mask_override: + return mask_override[vname] + if 'i' in dims and 'j' in dims: + point = 'c' + elif 'i_g' in dims and 'j' in dims: + point = 'w' + elif 'i' in dims and 'j_g' in dims: + point = 's' + elif 'i_g' in dims and 'j_g' in dims: + raise ValueError("Don't have masks for corner points!") + else: + raise ValueError("Variable `%s` is not a horizontal variable." % vname) + return point + def _decompress(data, mask, dtype): data_blank = np.full_like(mask, np.nan, dtype=dtype) @@ -405,7 +421,7 @@ def _chunks(l, n): def _get_facet_chunk(store, varname, iternum, nfacet, klevels, nx, nz, nfaces, - dtype, mask_override, domain, pad_before, pad_after): + dtype, mask_override, domain, pad_before, pad_after, dims): fs, path = store.get_fs_and_full_path(varname, iternum) @@ -423,7 +439,7 @@ def _get_facet_chunk(store, varname, iternum, nfacet, klevels, nx, nz, nfaces, if (store.shrunk and iternum is not None) or \ (store.shrunk_grid and iternum is None): # the store tells us whether we need a mask or not - point = _get_variable_point(varname, mask_override) + point = _get_variable_point(varname, dims, mask_override) mykey = nx if domain == 'global' else f'{domain}_{nx}' index = all_index_data[mykey][point] zgroup = store.open_mask_group() @@ -593,22 +609,6 @@ def _is_vgrid(self, vname): dims = self.var_metadata[vname]['dims'] return len(dims)==1 and dims[0][0]=='k' - def _get_variable_point(self, vname, mask_override): - # fix for https://github.com/MITgcm/xmitgcm/issues/191 - if vname in mask_override: - return mask_override[vname] - dims = self.var_metadata[vname]['dims'] - if 'i' in dims and 'j' in dims: - point = 'c' - elif 'i_g' in dims and 'j' in dims: - point = 'w' - elif 'i' in dims and 'j_g' in dims: - point = 's' - elif 'i_g' in dims and 'j_g' in dims: - raise ValueError("Don't have masks for corner points!") - else: - raise ValueError("Variable `%s` is not a horizontal variable." % vname) - return point def _get_scalars_and_vectors(self, varnames, type): @@ -701,10 +701,11 @@ def _key_and_task(n_k, these_klevels, n_iter=None, iternum=None): key = name, n_k, 0, 0, 0 else: key = name, n_iter, n_k, 0, 0, 0 + dims = self.var_metadata[varname]['dims'] task = (_get_facet_chunk, self.store, varname, iternum, nfacet, these_klevels, self.nx, self.nz, self.nface, dtype, self.mask_override, self.domain, - self.pad_before, self.pad_after) + self.pad_before, self.pad_after, dims) return key, task if iters is not None: