Skip to content

Commit

Permalink
Adding CohortMethods to StemTraits and StemAllometry
Browse files Browse the repository at this point in the history
  • Loading branch information
davidorme committed Oct 24, 2024
1 parent 6980eb8 commit dbffac0
Show file tree
Hide file tree
Showing 6 changed files with 100 additions and 11 deletions.
28 changes: 22 additions & 6 deletions pyrealm/demography/community.py
Original file line number Diff line number Diff line change
Expand Up @@ -538,22 +538,38 @@ def from_toml(cls, path: Path, flora: Flora) -> Community:

return cls(**file_data, flora=flora)

def add_cohorts(self, new_cohorts: Cohorts) -> None:
def drop_cohorts(self, drop_indices: NDArray[np.int_]) -> None:
"""Drop cohorts from the community.
This method drops the identified cohorts from the ``cohorts`` attribute and then
removes their data from the ``stem_traits`` and ``stem_allometry`` attributes
to match.
"""

self.cohorts.drop_cohort_data(drop_indices=drop_indices)
self.stem_traits.drop_cohort_data(drop_indices=drop_indices)
self.stem_allometry.drop_cohort_data(drop_indices=drop_indices)

def add_cohorts(self, new_data: Cohorts) -> None:
"""Add a new set of cohorts to the community.
This method extends the ``cohorts`` attribute with the new cohort data and then
also extends the ``stem_traits`` and ``stem_allometry`` to match.
Args:
new_data: An instance of :class:`~pyrealm.demography.community.Cohorts`
containing cohort data to add to the community.
"""

self.cohorts.add_cohort_data(new_cohorts)
self.cohorts.add_cohort_data(new_data=new_data)

new_stem_traits = self.flora.get_stem_traits(new_cohorts.pft_names)
self.stem_traits.add_cohort_data(new_stem_traits)
new_stem_traits = self.flora.get_stem_traits(pft_names=new_data.pft_names)
self.stem_traits.add_cohort_data(new_data=new_stem_traits)

new_stem_allometry = StemAllometry(
stem_traits=new_stem_traits, at_dbh=new_cohorts.dbh_values
stem_traits=new_stem_traits, at_dbh=new_data.dbh_values
)
self.stem_allometry.add_cohort_data(new_stem_allometry)
self.stem_allometry.add_cohort_data(new_data=new_stem_allometry)

# @classmethod
# def load_communities_from_csv(
Expand Down
12 changes: 11 additions & 1 deletion pyrealm/demography/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,5 +95,15 @@ def drop_cohort_data(self, drop_indices: NDArray[np.int_]) -> None:
drop_indices: An array of integer indices to drop from each array attribute.
"""

# TODO - Probably part of tackling #317
# The delete axis=0 here is tied to the case of dropping rows from 2D
# arrays, but then I'm thinking it makes more sense to _only_ support 2D
# arrays rather than the current mixed bag of getting a 1D array when a
# single height is provided. Promoting that kind of input to 2D and then
# enforcing an identical internal structure seems better.
# - But! Trait data does not have 2 dimensions!
# - Also to check here - this can lead to empty instances, which probably
# are a thing we want, if mortality removes all cohorts.

for trait in self.array_attrs:
setattr(self, trait, np.delete(getattr(self, trait), drop_indices))
setattr(self, trait, np.delete(getattr(self, trait), drop_indices, axis=0))
4 changes: 2 additions & 2 deletions pyrealm/demography/flora.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
from marshmallow.exceptions import ValidationError
from numpy.typing import NDArray

from pyrealm.demography.core import PandasExporter
from pyrealm.demography.core import CohortMethods, PandasExporter

if sys.version_info[:2] >= (3, 11):
import tomllib
Expand Down Expand Up @@ -443,7 +443,7 @@ def get_stem_traits(self, pft_names: NDArray[np.str_]) -> StemTraits:


@dataclass()
class StemTraits(PandasExporter):
class StemTraits(PandasExporter, CohortMethods):
"""A dataclass for stem traits.
This dataclass is used to provide arrays of plant functional type (PFT) traits
Expand Down
4 changes: 2 additions & 2 deletions pyrealm/demography/t_model_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from numpy.typing import NDArray

from pyrealm.core.utilities import check_input_shapes
from pyrealm.demography.core import PandasExporter
from pyrealm.demography.core import CohortMethods, PandasExporter
from pyrealm.demography.flora import Flora, StemTraits


Expand Down Expand Up @@ -660,7 +660,7 @@ def calculate_growth_increments(


@dataclass
class StemAllometry(PandasExporter):
class StemAllometry(PandasExporter, CohortMethods):
"""Calculate T Model allometric predictions across a set of stems.
This method calculate predictions of stem allometries for stem height, crown area,
Expand Down
31 changes: 31 additions & 0 deletions tests/unit/demography/test_flora.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,3 +379,34 @@ def test_StemTraits(fixture_flora):
)

assert set(instance.array_attrs) == set(stem_traits_df.columns)


def test_StemTraits_CohortMethods(fixture_flora):
"""Test the StemTraits inherited cohort methods."""

from pyrealm.demography.t_model_functions import StemTraits

# Construct some input data with duplicate PFTs by doubling the fixture_flora data
flora_df = fixture_flora.to_pandas()
args = {ky: np.concatenate([val, val]) for ky, val in flora_df.items()}

stem_traits = StemTraits(**args)

# Check failure mode
with pytest.raises(ValueError) as excep:
stem_traits.add_cohort_data(new_data=dict(a=1))

assert (
str(excep.value) == "Cannot add cohort data from an dict instance to StemTraits"
)

# Check success of adding and dropping data
# Add a copy of itself as new cohort data and check the shape
stem_traits.add_cohort_data(new_data=stem_traits)
assert stem_traits.h_max.shape == (4 * fixture_flora.n_pfts,)
assert stem_traits.h_max.sum() == 4 * flora_df["h_max"].sum()

# Remove all but the first two rows and what's left should be aligned with the
# original data
stem_traits.drop_cohort_data(drop_indices=np.arange(2, 8))
assert np.allclose(stem_traits.h_max, flora_df["h_max"])
32 changes: 32 additions & 0 deletions tests/unit/demography/test_t_model_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -755,6 +755,38 @@ def test_StemAllometry(rtmodel_flora, rtmodel_data):
assert set(stem_allometry.array_attrs) == set(df.columns)


def test_StemAllometry_CohortMethods(rtmodel_flora, rtmodel_data):
"""Test the StemAllometry inherited cohort methods."""

from pyrealm.demography.t_model_functions import StemAllometry

stem_allometry = StemAllometry(
stem_traits=rtmodel_flora, at_dbh=rtmodel_data["dbh"][:, [0]]
)
check_data = stem_allometry.crown_fraction.copy()

# Check failure mode
with pytest.raises(ValueError) as excep:
stem_allometry.add_cohort_data(new_data=dict(a=1))

assert (
str(excep.value)
== "Cannot add cohort data from an dict instance to StemAllometry"
)

# Check success of adding and dropping data
n_entries = len(rtmodel_data["dbh"])
# Add a copy of itself as new cohort data and check the shape
stem_allometry.add_cohort_data(new_data=stem_allometry)
assert stem_allometry.crown_fraction.shape == (2 * n_entries, rtmodel_flora.n_pfts)
assert stem_allometry.crown_fraction.sum() == 2 * check_data.sum()

# Remove the rows from the first copy and what's left should be aligned with the
# original data
stem_allometry.drop_cohort_data(drop_indices=np.arange(n_entries))
assert np.allclose(stem_allometry.crown_fraction, check_data)


def test_StemAllocation(rtmodel_flora, rtmodel_data):
"""Test the StemAllometry class."""

Expand Down

0 comments on commit dbffac0

Please sign in to comment.