Skip to content

Commit

Permalink
Introduce a value range check on the numpy level (#43)
Browse files Browse the repository at this point in the history
* added range_check and tests

* simplify range check and add performance message

* Move warning to xarray wrappers

* Fix leftover bug

* Update source/aerobulk/flux.py

Co-authored-by: Paige Martin <[email protected]>

Co-authored-by: Paige Martin <[email protected]>
  • Loading branch information
jbusecke and paigem authored Jun 30, 2022
1 parent a69c96f commit 8a802fc
Show file tree
Hide file tree
Showing 4 changed files with 178 additions and 34 deletions.
102 changes: 82 additions & 20 deletions source/aerobulk/flux.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,57 @@
import warnings

import aerobulk.aerobulk.mod_aerobulk_wrap_noskin as aeronoskin
import aerobulk.aerobulk.mod_aerobulk_wrap_skin as aeroskin
import numpy as np
import xarray as xr

VALID_ALGOS = ["coare3p0", "coare3p6", "ecmwf", "ncar", "andreas"]
VALID_ALGOS_SKIN = ["coare3p0", "coare3p6", "ecmwf"]
VALID_VALUE_RANGES = {
"sst": [270, 320],
"t_zt": [180, 330],
"hum_zt": [0, 0.08],
"u_zu": [-50, 50],
"v_zu": [-50, 50],
"slp": [80000, 110000],
"rad_sw": [0, 1500],
"rad_lw": [0, 750],
}


def _check_algo(algo, valids):
if algo not in valids:
raise ValueError(f"Algorithm {algo} not valid. Choose from {valids}.")


def _check_value_range(*args):
"""Checks the input ranges for input fields"""
# parse inputs to names
args_dict = {
k: v
for k, v in zip(
["sst", "t_zt", "hum_zt", "u_zu", "v_zu", "slp", "rad_sw", "rad_lw"], args
)
}
for var, data in args_dict.items():
# check for misaligned nans
if var != "sst":
if np.isnan(data).any():
raise ValueError(
f"Found nans in {var} that do not align with nans in `sst`. Check that nans in all fields are matched."
)

# check for valid range
range = VALID_VALUE_RANGES[var]

# check that values are in range
out_of_range = ~np.logical_and(data >= range[0], data <= range[1])
if out_of_range.any():
raise ValueError(
f"Found values in {var} that are out of the valid range ({range[0]}-{range[1]})."
)


# Unshrink the data (i.e. put land NaN values back in their correct locations)
def unshrink_arr(shrunk_array, shape, ocean_index):
unshrunk_array = np.full(shape, np.nan)
Expand All @@ -20,16 +60,7 @@ def unshrink_arr(shrunk_array, shape, ocean_index):


def noskin_np(
sst,
t_zt,
hum_zt,
u_zu,
v_zu,
slp,
algo,
zt,
zu,
niter,
sst, t_zt, hum_zt, u_zu, v_zu, slp, algo, zt, zu, niter, input_range_check
):
"""Python wrapper for aerobulk without skin correction.
!ATTENTION If input not provided in correct units, will crash.
Expand Down Expand Up @@ -81,6 +112,9 @@ def noskin_np(
np.atleast_3d(a[ocean_index]) for a in (sst, t_zt, hum_zt, u_zu, v_zu, slp)
)

if input_range_check:
_check_value_range(*args_shrunk)

out_data = aeronoskin.mod_aerobulk_wrapper_noskin.aerobulk_model_noskin(
algo, zt, zu, *args_shrunk, niter
)
Expand All @@ -94,13 +128,14 @@ def skin_np(
hum_zt,
u_zu,
v_zu,
slp,
rad_sw,
rad_lw,
slp,
algo,
zt,
zu,
niter,
input_range_check,
):
"""Python wrapper for aerobulk with skin correction.
!ATTENTION If input not provided in correct units, will crash.
Expand Down Expand Up @@ -158,6 +193,9 @@ def skin_np(
for a in (sst, t_zt, hum_zt, u_zu, v_zu, slp, rad_sw, rad_lw)
)

if input_range_check:
_check_value_range(*args_shrunk)

out_data = aeroskin.mod_aerobulk_wrapper_skin.aerobulk_model_skin(
algo, zt, zu, *args_shrunk, niter
)
Expand All @@ -166,7 +204,17 @@ def skin_np(


def noskin(
sst, t_zt, hum_zt, u_zu, v_zu, slp=101000.0, algo="coare3p0", zt=2, zu=10, niter=6
sst,
t_zt,
hum_zt,
u_zu,
v_zu,
slp=101000.0,
algo="coare3p0",
zt=2,
zu=10,
niter=6,
input_range_check=True,
):
"""xarray wrapper for aerobulk without skin correction.
Expand Down Expand Up @@ -202,6 +250,9 @@ def noskin(
niter : int, optional
Number of iteration steps used in the algorithm,
by default 6
input_range_check: bool, optional
Turn on/off explicit checking of input variables for valid ranges.
On by default, but for best performance should be turned off
Returns
-------
Expand All @@ -221,6 +272,12 @@ def noskin(
sst, t_zt, hum_zt, u_zu, v_zu, slp = xr.broadcast(
sst, t_zt, hum_zt, u_zu, v_zu, slp
)
if input_range_check:
performance_msg = (
"Checking for misaligned nans and values outside of the valid range is performed by default, but reduces performance. \n"
"If you are sure your data is valid you can deactivate these checks by setting `input_range_check=False`"
)
warnings.warn(performance_msg)

out_vars = xr.apply_ufunc(
noskin_np,
Expand All @@ -234,10 +291,7 @@ def noskin(
output_core_dims=[()] * 5,
dask="parallelized",
kwargs=dict(
algo=algo,
zt=zt,
zu=zu,
niter=niter,
algo=algo, zt=zt, zu=zu, niter=niter, input_range_check=input_range_check
),
output_dtypes=[sst.dtype]
* 5, # deactivates the 1 element check which aerobulk does not like
Expand All @@ -262,6 +316,7 @@ def skin(
zt=2,
zu=10,
niter=6,
input_range_check=True,
):

"""xarray wrapper for aerobulk with skin correction.
Expand Down Expand Up @@ -302,6 +357,9 @@ def skin(
niter : int, optional
Number of iteration steps used in the algorithm,
by default 6
input_range_check: bool, optional
Turn on/off explicit checking of input variables for valid ranges.
On by default, but for best performance should be turned off
Returns
-------
Expand All @@ -325,6 +383,13 @@ def skin(
sst, t_zt, hum_zt, u_zu, v_zu, rad_sw, rad_lw, slp
)

if input_range_check:
performance_msg = (
"Checking for misaligned nans and values outside of the valid range is performed by default, but reduces performance. \n"
"If you are sure your data is valid you can deactivate these checks by setting `input_range_check=False`"
)
warnings.warn(performance_msg)

out_vars = xr.apply_ufunc(
skin_np,
sst,
Expand All @@ -339,10 +404,7 @@ def skin(
output_core_dims=[()] * 6,
dask="parallelized",
kwargs=dict(
algo=algo,
zt=zt,
zu=zu,
niter=niter,
algo=algo, zt=zt, zu=zu, niter=niter, input_range_check=input_range_check
),
output_dtypes=[sst.dtype]
* 6, # deactivates the 1 element check which aerobulk does not like
Expand Down
34 changes: 24 additions & 10 deletions tests/create_test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,14 @@ def create_data(
order: str = "F",
use_xr=True,
land_mask=False,
sst=290.0,
t_zt=280.0,
hum_zt=0.001,
u_zu=1.0,
v_zu=-1.0,
slp=101000.0,
rad_sw=0.000001,
rad_lw=350.0,
):
size = shape[0] * shape[1]
shape2d = (shape[0], shape[1])
Expand All @@ -37,15 +45,21 @@ def _arr(value, chunks, order):
arr = arr.chunk(chunks)
return arr

sst = _arr(290.0, chunks, order)
t_zt = _arr(280.0, chunks, order)
hum_zt = _arr(0.001, chunks, order)
u_zu = _arr(1.0, chunks, order)
v_zu = _arr(-1.0, chunks, order)
slp = _arr(101000.0, chunks, order)
rad_sw = _arr(0.000001, chunks, order)
rad_lw = _arr(350.0, chunks, order)
if skin_correction:
return sst, t_zt, hum_zt, u_zu, v_zu, rad_sw, rad_lw, slp
return tuple(
_arr(a, chunks, order)
for a in (
sst,
t_zt,
hum_zt,
u_zu,
v_zu,
slp,
rad_sw,
rad_lw,
)
)
else:
return sst, t_zt, hum_zt, u_zu, v_zu, slp
return tuple(
_arr(a, chunks, order) for a in (sst, t_zt, hum_zt, u_zu, v_zu, slp)
)
61 changes: 59 additions & 2 deletions tests/test_flux_np.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def test_land_mask(skin_correction, algo):
use_xr=False,
land_mask=True,
)
out_data = func(*args, algo, 2, 10, 6)
out_data = func(*args, algo, 2, 10, 6, input_range_check=True)

# Check the location of all NaNs is correct
for o in out_data:
Expand All @@ -47,6 +47,63 @@ def test_land_mask(skin_correction, algo):
if not np.isnan(out_data[0][index]):
single_inputs = tuple(np.atleast_3d(i[index]) for i in args)

single_outputs = func(*single_inputs, algo, 2, 10, 6)
single_outputs = func(
*single_inputs, algo, 2, 10, 6, input_range_check=True
)
for so, o in zip(single_outputs, out_data):
assert so == o[index]


@pytest.mark.parametrize(
"varname",
[
"t_zt",
"hum_zt",
"u_zu",
"v_zu",
"slp",
"rad_sw",
"rad_lw",
],
)
class Test_Range_Check:
def test_range_check_nan(self, varname):
shape = (1, 1)
args_noskin = create_data(shape, skin_correction=False, **{varname: np.nan})
args_skin = create_data(shape, skin_correction=True, **{varname: np.nan})
msg = f"Found nans in {varname} that do not align with nans in `sst`. Check that nans in all fields are matched."

if varname not in ["rad_sw", "rad_lw"]: # Test these only for skin
with pytest.raises(ValueError, match=msg):
noskin_np(*args_noskin, "ecmwf", 2, 10, 6, input_range_check=True)

with pytest.raises(ValueError, match=msg):
skin_np(*args_skin, "ecmwf", 2, 10, 6, input_range_check=True)

def test_range_check_invalid(self, varname):
invalid_values = {
"sst": 200.0,
"t_zt": 100,
"hum_zt": 0.1,
"u_zu": 60,
"v_zu": 60,
"slp": 2000,
"rad_sw": -20,
"rad_lw": 4000,
}
shape = (1, 1)
args_noskin = create_data(
shape, skin_correction=False, **{varname: invalid_values[varname]}
)
args_skin = create_data(
shape, skin_correction=True, **{varname: invalid_values[varname]}
)
# I hate regex sooo much. If someone has a nice way to just test that the varname is in here and the error message contains 'range'
# that would be amazing. This is the best I could do...
msg = r"\b(?:out of the valid range)\b"
if varname not in ["rad_sw", "rad_lw"]:
with pytest.raises(ValueError, match=str(msg)):
noskin_np(*args_noskin, "ecmwf", 2, 10, 6, input_range_check=True)

with pytest.raises(ValueError, match=str(msg)):
skin_np(*args_skin, "ecmwf", 2, 10, 6, input_range_check=True)
15 changes: 13 additions & 2 deletions tests/test_fortran.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,17 @@ def test_fortran_noskin(algo, expected):
v_zu = np.atleast_3d(0)
slp = np.atleast_3d(101000.0) # Pa
ql, qh, taux, tauy, evap = noskin_np(
sst, t_zt, hum_zt, u_zu, v_zu, slp, algo=algo, zt=2, zu=10, niter=10
sst,
t_zt,
hum_zt,
u_zu,
v_zu,
slp,
algo=algo,
zt=2,
zu=10,
niter=10,
input_range_check=True,
)
evap = evap * 3600 * 24 # convert to mm/day

Expand Down Expand Up @@ -108,13 +118,14 @@ def test_fortran_skin(algo, expected):
hum_zt,
u_zu,
v_zu,
slp,
rad_sw,
rad_lw,
slp,
algo=algo,
zt=2,
zu=10,
niter=10,
input_range_check=True,
)
evap = evap * 3600 * 24 # convert to mm/day
t_s = t_s - rt0
Expand Down

0 comments on commit 8a802fc

Please sign in to comment.