diff --git a/pyrealm/demography/t_model_functions.py b/pyrealm/demography/t_model_functions.py index 8a35f029..30a36317 100644 --- a/pyrealm/demography/t_model_functions.py +++ b/pyrealm/demography/t_model_functions.py @@ -625,58 +625,69 @@ def calculate_t_model_allometry( allometry values. """ - stem_data = {"dbh": dbh} - - stem_data["stem_height"] = calculate_heights( + stem_height = calculate_heights( h_max=pft_data["h_max"], a_hd=pft_data["a_hd"], - dbh=stem_data["dbh"], + dbh=dbh, ) - stem_data["crown_area"] = calculate_crown_areas( + # Broadcast dbh to shape of stem height to get congruent shapes + dbh = np.broadcast_to(dbh, stem_height.shape) + + crown_area = calculate_crown_areas( ca_ratio=pft_data["ca_ratio"], a_hd=pft_data["a_hd"], - dbh=stem_data["dbh"], - stem_height=stem_data["stem_height"], + dbh=dbh, + stem_height=stem_height, ) - stem_data["crown_fraction"] = calculate_crown_fractions( + crown_fraction = calculate_crown_fractions( a_hd=pft_data["a_hd"], - dbh=stem_data["dbh"], - stem_height=stem_data["stem_height"], + dbh=dbh, + stem_height=stem_height, ) - stem_data["stem_mass"] = calculate_stem_masses( + stem_mass = calculate_stem_masses( rho_s=pft_data["rho_s"], - dbh=stem_data["dbh"], - stem_height=stem_data["stem_height"], + dbh=dbh, + stem_height=stem_height, ) - stem_data["foliage_mass"] = calculate_foliage_masses( + foliage_mass = calculate_foliage_masses( sla=pft_data["sla"], lai=pft_data["lai"], - crown_area=stem_data["crown_area"], + crown_area=crown_area, ) - stem_data["sapwood_mass"] = calculate_sapwood_masses( + sapwood_mass = calculate_sapwood_masses( rho_s=pft_data["rho_s"], ca_ratio=pft_data["ca_ratio"], - stem_height=stem_data["stem_height"], - crown_area=stem_data["crown_area"], - crown_fraction=stem_data["crown_fraction"], + stem_height=stem_height, + crown_area=crown_area, + crown_fraction=crown_fraction, ) - stem_data["canopy_r0"] = calculate_canopy_r0( + canopy_r0 = calculate_canopy_r0( q_m=pft_data["q_m"], - crown_area=stem_data["crown_area"], + crown_area=crown_area, ) - stem_data["canopy_z_max"] = calculate_canopy_z_max( + canopy_z_max = calculate_canopy_z_max( z_max_prop=pft_data["z_max_prop"], - stem_height=stem_data["stem_height"], + stem_height=stem_height, ) - return stem_data + return dict( + dbh=dbh, + stem_height=stem_height, + crown_area=crown_area, + crown_fraction=crown_fraction, + stem_mass=stem_mass, + foliage_mass=foliage_mass, + sapwood_mass=sapwood_mass, + canopy_r0=canopy_r0, + canopy_z_max=canopy_z_max, + ) def calculate_t_model_allocation( diff --git a/tests/unit/demography/test_flora.py b/tests/unit/demography/test_flora.py index 317f6cae..04a84e33 100644 --- a/tests/unit/demography/test_flora.py +++ b/tests/unit/demography/test_flora.py @@ -273,10 +273,12 @@ def test_Flora_get_allometries_dbh_against_rtmodel(rtmodel_data, rtmodel_flora): correct values. So the shapes go (3,) x (6, 1) -> (6,3) """ - result = rtmodel_flora.get_allometries(dbh=rtmodel_data["dbh"][:, 0]) + result = rtmodel_flora.get_allometries(dbh=rtmodel_data["dbh"][:, [0]]) for key, value in result.items(): - assert np.allclose(value, rtmodel_data[key]) + # Skip canopy shape allometries that are not in the original T Model + if key not in ("canopy_r0", "canopy_z_max"): + assert np.allclose(value, rtmodel_data[key]) def test_Flora_get_allometries_stem_height_against_rtmodel(rtmodel_data, rtmodel_flora): @@ -291,11 +293,13 @@ def test_Flora_get_allometries_stem_height_against_rtmodel(rtmodel_data, rtmodel for idx, (name, pft) in enumerate(rtmodel_flora.items()): single_pft_flora = Flora(pfts=[pft]) result = single_pft_flora.get_allometries( - stem_height=rtmodel_data["stem_height"][:, idx] + stem_height=rtmodel_data["stem_height"][:, [idx]] ) for key, value in result.items(): - assert np.allclose(value, rtmodel_data[key][:, [idx]]) + # Skip canopy shape allometries that are not in the original T Model + if key not in ("canopy_r0", "canopy_z_max"): + assert np.allclose(value, rtmodel_data[key][:, [idx]]) @pytest.mark.parametrize( @@ -319,14 +323,16 @@ def test_Flora_get_allometries_stem_height_against_rtmodel(rtmodel_data, rtmodel np.array([[0.1, 0.2, 0.3]]), None, pytest.raises(ValueError), - "DBH must be a one dimensional array", + "The z argument is two dimensional (shape: (1, 3)) " + "but is not a column array.", # TODO - fix to DBH not z id="fail_dbh_not_1D", ), pytest.param( None, np.array([[10, 20, 30]]), pytest.raises(ValueError), - "Stem heights must be a one dimensional array", + "The z argument is two dimensional (shape: (1, 3)) " + "but is not a column array.", # TODO - fix to stem_height not z id="fail_stem_height_not_1D", ), pytest.param( @@ -334,18 +340,32 @@ def test_Flora_get_allometries_stem_height_against_rtmodel(rtmodel_data, rtmodel None, does_not_raise(), None, - id="ok_with_dbh", + id="ok_with_dbh_as_row", ), pytest.param( None, np.array([5, 10, 15]), does_not_raise(), None, - id="ok_with_stem_heights", + id="ok_with_stem_heights_as_row", ), pytest.param( + np.array([[0.1], [0.2], [0.3]]), None, - np.array([0, 5, 10, 15, 45.33, 1000]), + does_not_raise(), + None, + id="ok_with_dbh_as_col", + ), + pytest.param( + None, + np.array([[5], [10], [15]]), + does_not_raise(), + None, + id="ok_with_stem_heights_as_col", + ), + pytest.param( + None, + np.array([[0], [5], [10], [15], [45.33], [1000]]), does_not_raise(), None, id="ok_with_edgy_stem_heights",