Skip to content

Commit

Permalink
Fixing broadcast issue in get_allometries, fix up failing tests
Browse files Browse the repository at this point in the history
  • Loading branch information
davidorme committed Sep 26, 2024
1 parent 5b61e22 commit c415feb
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 33 deletions.
59 changes: 35 additions & 24 deletions pyrealm/demography/t_model_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -625,58 +625,69 @@ def calculate_t_model_allometry(
allometry values.
"""

stem_data = {"dbh": dbh}

stem_data["stem_height"] = calculate_heights(
stem_height = calculate_heights(
h_max=pft_data["h_max"],
a_hd=pft_data["a_hd"],
dbh=stem_data["dbh"],
dbh=dbh,
)

stem_data["crown_area"] = calculate_crown_areas(
# Broadcast dbh to shape of stem height to get congruent shapes
dbh = np.broadcast_to(dbh, stem_height.shape)

crown_area = calculate_crown_areas(
ca_ratio=pft_data["ca_ratio"],
a_hd=pft_data["a_hd"],
dbh=stem_data["dbh"],
stem_height=stem_data["stem_height"],
dbh=dbh,
stem_height=stem_height,
)

stem_data["crown_fraction"] = calculate_crown_fractions(
crown_fraction = calculate_crown_fractions(
a_hd=pft_data["a_hd"],
dbh=stem_data["dbh"],
stem_height=stem_data["stem_height"],
dbh=dbh,
stem_height=stem_height,
)

stem_data["stem_mass"] = calculate_stem_masses(
stem_mass = calculate_stem_masses(
rho_s=pft_data["rho_s"],
dbh=stem_data["dbh"],
stem_height=stem_data["stem_height"],
dbh=dbh,
stem_height=stem_height,
)

stem_data["foliage_mass"] = calculate_foliage_masses(
foliage_mass = calculate_foliage_masses(
sla=pft_data["sla"],
lai=pft_data["lai"],
crown_area=stem_data["crown_area"],
crown_area=crown_area,
)

stem_data["sapwood_mass"] = calculate_sapwood_masses(
sapwood_mass = calculate_sapwood_masses(
rho_s=pft_data["rho_s"],
ca_ratio=pft_data["ca_ratio"],
stem_height=stem_data["stem_height"],
crown_area=stem_data["crown_area"],
crown_fraction=stem_data["crown_fraction"],
stem_height=stem_height,
crown_area=crown_area,
crown_fraction=crown_fraction,
)

stem_data["canopy_r0"] = calculate_canopy_r0(
canopy_r0 = calculate_canopy_r0(
q_m=pft_data["q_m"],
crown_area=stem_data["crown_area"],
crown_area=crown_area,
)

stem_data["canopy_z_max"] = calculate_canopy_z_max(
canopy_z_max = calculate_canopy_z_max(
z_max_prop=pft_data["z_max_prop"],
stem_height=stem_data["stem_height"],
stem_height=stem_height,
)

return stem_data
return dict(
dbh=dbh,
stem_height=stem_height,
crown_area=crown_area,
crown_fraction=crown_fraction,
stem_mass=stem_mass,
foliage_mass=foliage_mass,
sapwood_mass=sapwood_mass,
canopy_r0=canopy_r0,
canopy_z_max=canopy_z_max,
)


def calculate_t_model_allocation(
Expand Down
38 changes: 29 additions & 9 deletions tests/unit/demography/test_flora.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,10 +273,12 @@ def test_Flora_get_allometries_dbh_against_rtmodel(rtmodel_data, rtmodel_flora):
correct values. So the shapes go (3,) x (6, 1) -> (6,3)
"""

result = rtmodel_flora.get_allometries(dbh=rtmodel_data["dbh"][:, 0])
result = rtmodel_flora.get_allometries(dbh=rtmodel_data["dbh"][:, [0]])

for key, value in result.items():
assert np.allclose(value, rtmodel_data[key])
# Skip canopy shape allometries that are not in the original T Model
if key not in ("canopy_r0", "canopy_z_max"):
assert np.allclose(value, rtmodel_data[key])


def test_Flora_get_allometries_stem_height_against_rtmodel(rtmodel_data, rtmodel_flora):
Expand All @@ -291,11 +293,13 @@ def test_Flora_get_allometries_stem_height_against_rtmodel(rtmodel_data, rtmodel
for idx, (name, pft) in enumerate(rtmodel_flora.items()):
single_pft_flora = Flora(pfts=[pft])
result = single_pft_flora.get_allometries(
stem_height=rtmodel_data["stem_height"][:, idx]
stem_height=rtmodel_data["stem_height"][:, [idx]]
)

for key, value in result.items():
assert np.allclose(value, rtmodel_data[key][:, [idx]])
# Skip canopy shape allometries that are not in the original T Model
if key not in ("canopy_r0", "canopy_z_max"):
assert np.allclose(value, rtmodel_data[key][:, [idx]])


@pytest.mark.parametrize(
Expand All @@ -319,33 +323,49 @@ def test_Flora_get_allometries_stem_height_against_rtmodel(rtmodel_data, rtmodel
np.array([[0.1, 0.2, 0.3]]),
None,
pytest.raises(ValueError),
"DBH must be a one dimensional array",
"The z argument is two dimensional (shape: (1, 3)) "
"but is not a column array.", # TODO - fix to DBH not z
id="fail_dbh_not_1D",
),
pytest.param(
None,
np.array([[10, 20, 30]]),
pytest.raises(ValueError),
"Stem heights must be a one dimensional array",
"The z argument is two dimensional (shape: (1, 3)) "
"but is not a column array.", # TODO - fix to stem_height not z
id="fail_stem_height_not_1D",
),
pytest.param(
np.array([0.1, 0.2, 0.3]),
None,
does_not_raise(),
None,
id="ok_with_dbh",
id="ok_with_dbh_as_row",
),
pytest.param(
None,
np.array([5, 10, 15]),
does_not_raise(),
None,
id="ok_with_stem_heights",
id="ok_with_stem_heights_as_row",
),
pytest.param(
np.array([[0.1], [0.2], [0.3]]),
None,
np.array([0, 5, 10, 15, 45.33, 1000]),
does_not_raise(),
None,
id="ok_with_dbh_as_col",
),
pytest.param(
None,
np.array([[5], [10], [15]]),
does_not_raise(),
None,
id="ok_with_stem_heights_as_col",
),
pytest.param(
None,
np.array([[0], [5], [10], [15], [45.33], [1000]]),
does_not_raise(),
None,
id="ok_with_edgy_stem_heights",
Expand Down

0 comments on commit c415feb

Please sign in to comment.