Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement stats calculations for new calculated_stats_ container #468

Closed
wants to merge 31 commits into from
Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
d00133e
Implement calcluated stats
TinyMarsh Jul 3, 2024
ed573ed
Remove unused methods for now
TinyMarsh Jul 4, 2024
3b259be
Use size_t as value type
TinyMarsh Jul 17, 2024
891b71d
Simplify function by splitting into smaller functions
TinyMarsh Jul 17, 2024
155757c
Simplify finding channel index
TinyMarsh Jul 17, 2024
f915dcc
Remove unreachable throw
TinyMarsh Jul 17, 2024
88ca8ce
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 17, 2024
acf8067
Remove unused variable
TinyMarsh Jul 17, 2024
52e23d9
static_cast to avoid comparing integers of different signs
TinyMarsh Jul 17, 2024
e4656e2
Add test for calculate_index function
TinyMarsh Aug 15, 2024
dfc7e12
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 15, 2024
fa85412
Move simulation logic to shared location
TinyMarsh Aug 17, 2024
34d625a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 17, 2024
a7d8000
Amend test fixture and analysis_module code for running test
TinyMarsh Aug 31, 2024
85f23a9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 31, 2024
9070c85
Make clang-tidy happy
TinyMarsh Sep 5, 2024
97a8ca5
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 5, 2024
2ce17d8
Add missing #include
alexdewar Aug 29, 2024
ed0bcba
Add necessary include
TinyMarsh Sep 6, 2024
e8066f1
Cast as uint to avoid warning
TinyMarsh Sep 6, 2024
2f55b9f
Remove unused RiskFactorData code
TinyMarsh Oct 7, 2024
09bbeb5
Increase population diversity, multiple tests
TinyMarsh Oct 8, 2024
c99d6cf
Fix logic and implement feedback
TinyMarsh Oct 9, 2024
b5c1fd5
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 9, 2024
ddc5198
Remove commented code and set context time
TinyMarsh Oct 11, 2024
e74050d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 11, 2024
49227a9
Static cast to uint
TinyMarsh Oct 11, 2024
6c592d3
Merge pull request #495 from imperialCHEPI/calculate_index_test
TinyMarsh Oct 14, 2024
ff7fa5a
Merge branch 'main' into extend_analysis
TinyMarsh Nov 15, 2024
4997ac4
clang-format fix
TinyMarsh Nov 15, 2024
667200c
Dereference config pointer
TinyMarsh Nov 15, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
154 changes: 132 additions & 22 deletions src/HealthGPS/analysis_module.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -319,36 +319,88 @@ DALYsIndicator AnalysisModule::calculate_dalys(Population &population, unsigned
}

void AnalysisModule::calculate_population_statistics(RuntimeContext &context) {
jamesturner246 marked this conversation as resolved.
Show resolved Hide resolved
TinyMarsh marked this conversation as resolved.
Show resolved Hide resolved
size_t num_factors_to_calculate =
context.mapping().entries().size() - factors_to_calculate_.size();

auto current_time = static_cast<unsigned int>(context.time_now());

for (const auto &person : context.population()) {
// Get the bin index for each factor
std::vector<size_t> bin_indices;
for (size_t i = 0; i < factors_to_calculate_.size(); i++) {
double factor_value = person.get_risk_factor_value(factors_to_calculate_[i]);
auto bin_index =
static_cast<size_t>((factor_value - factor_min_values_[i]) / factor_bin_widths_[i]);
bin_indices.push_back(bin_index);
}
// First let's fetch the correct `calculated_stats_` bin index for this person
size_t index = calculate_index(person);

// Now we can add the calculated stats for this person to the correct index
if (!person.is_active()) {
if (!person.is_alive() && person.time_of_death() == current_time) {
calculated_stats_[index + get_channel_index("deaths")]++;
float expected_life =
definition_.life_expectancy().at(context.time_now(), person.gender);
double yll = std::max(expected_life - person.age, 0.0f) * DALY_UNITS;
calculated_stats_[index + get_channel_index("mean_yll")] += yll;
calculated_stats_[index + get_channel_index("mean_daly")] += yll;
}

if (person.has_emigrated() && person.time_of_migration() == current_time) {
calculated_stats_[index + get_channel_index("emigrations")]++;
}

// Calculate the index in the calculated_stats_ vector
size_t index = 0;
for (size_t i = 0; i < bin_indices.size() - 1; i++) {
size_t accumulated_bins =
std::accumulate(std::next(factor_bins_.cbegin(), i + 1), factor_bins_.cend(),
size_t{1}, std::multiplies<>());
index += bin_indices[i] * accumulated_bins * num_factors_to_calculate;
continue;
}
index += bin_indices.back() * num_factors_to_calculate;

// Now we can add the values of the factors that are not in factors_to_calculate_
calculated_stats_[index + get_channel_index("count")]++;

for (const auto &factor : context.mapping().entries()) {
if (std::find(factors_to_calculate_.cbegin(), factors_to_calculate_.cend(),
factor.key()) == factors_to_calculate_.cend()) {
calculated_stats_[index++] += person.get_risk_factor_value(factor.key());
double value = person.get_risk_factor_value(factor.key());
calculated_stats_[index + get_channel_index("mean_" + factor.key().to_string())] +=
value;
}

for (const auto &[disease_name, disease_state] : person.diseases) {
if (disease_state.status == DiseaseStatus::active) {
calculated_stats_[index +
get_channel_index("prevalence_" + disease_name.to_string())]++;
if (disease_state.start_time == context.time_now()) {
calculated_stats_[index +
get_channel_index("incidence_" + disease_name.to_string())]++;
}
}
}

double dw = calculate_disability_weight(person);
double yld = dw * DALY_UNITS;
calculated_stats_[index + get_channel_index("mean_yld")] += yld;
calculated_stats_[index + get_channel_index("mean_daly")] += yld;

classify_weight(person);
}

// For each bin in the calculated stats...
for (size_t i = 0; i < calculated_stats_.size(); i += channels_.size()) {
double count_F = calculated_stats_[i + get_channel_index("count")];
double count_M = calculated_stats_[i + get_channel_index("count")];
double deaths_F = calculated_stats_[i + get_channel_index("deaths")];
double deaths_M = calculated_stats_[i + get_channel_index("deaths")];

// Calculate in-place factor averages.
for (const auto &factor : context.mapping().entries()) {
calculated_stats_[i + get_channel_index("mean_" + factor.key().to_string())] /= count_F;
calculated_stats_[i + get_channel_index("mean_" + factor.key().to_string())] /= count_M;
}

// Calculate in-place disease prevalence and incidence rates.
for (const auto &disease : context.diseases()) {
calculated_stats_[i + get_channel_index("prevalence_" + disease.code.to_string())] /=
count_F;
calculated_stats_[i + get_channel_index("prevalence_" + disease.code.to_string())] /=
count_M;
calculated_stats_[i + get_channel_index("incidence_" + disease.code.to_string())] /=
count_F;
calculated_stats_[i + get_channel_index("incidence_" + disease.code.to_string())] /=
count_M;
}

// Calculate in-place YLL/YLD/DALY averages.
for (const auto &column : {"mean_yll", "mean_yld", "mean_daly"}) {
calculated_stats_[i + get_channel_index(column)] /= (count_F + deaths_F);
calculated_stats_[i + get_channel_index(column)] /= (count_M + deaths_M);
}
}
}

Expand Down Expand Up @@ -531,6 +583,26 @@ void AnalysisModule::classify_weight(DataSeries &series, const Person &entity) c
}
}

void AnalysisModule::classify_weight(const Person &person) {
auto weight_class = weight_classifier_.classify_weight(person);
switch (weight_class) {
case WeightCategory::normal:
calculated_stats_[get_channel_index("normal_weight")]++;
break;
case WeightCategory::overweight:
calculated_stats_[get_channel_index("over_weight")]++;
calculated_stats_[get_channel_index("above_weight")]++;
TinyMarsh marked this conversation as resolved.
Show resolved Hide resolved
break;
case WeightCategory::obese:
calculated_stats_[get_channel_index("obese_weight")]++;
calculated_stats_[get_channel_index("above_weight")]++;
break;
default:
throw std::logic_error("Unknown weight classification category.");
break;
TinyMarsh marked this conversation as resolved.
Show resolved Hide resolved
}
}

void AnalysisModule::initialise_output_channels(RuntimeContext &context) {
if (!channels_.empty()) {
return;
Expand Down Expand Up @@ -560,6 +632,44 @@ void AnalysisModule::initialise_output_channels(RuntimeContext &context) {
channels_.emplace_back("std_yld");
channels_.emplace_back("mean_daly");
channels_.emplace_back("std_daly");

// Since we will be performing frequent lookups, we will store the strings and indexes in a map
// for quick access.
for (size_t i = 0; i < channels_.size(); i++) {
channel_index_.emplace(channels_[i], i);
}
}

size_t AnalysisModule::calculate_index(const Person &person) const {
TinyMarsh marked this conversation as resolved.
Show resolved Hide resolved
// Get the bin index for each factor
std::vector<size_t> bin_indices;
for (size_t i = 0; i < factors_to_calculate_.size(); i++) {
double factor_value = person.get_risk_factor_value(factors_to_calculate_[i]);
auto bin_index =
static_cast<size_t>((factor_value - factor_min_values_[i]) / factor_bin_widths_[i]);
bin_indices.push_back(bin_index);
}

// Calculate the index in the calculated_stats_ vector
size_t index = 0;
for (size_t i = 0; i < bin_indices.size() - 1; i++) {
size_t accumulated_bins =
std::accumulate(std::next(factor_bins_.cbegin(), i + 1), factor_bins_.cend(), size_t{1},
std::multiplies<>());
index += bin_indices[i] * accumulated_bins * channels_.size();
}
index += bin_indices.back() * channels_.size();

return index;
}

size_t AnalysisModule::get_channel_index(const std::string &channel) const {
auto it = channel_index_.find(channel);
if (it == channel_index_.end()) {
throw std::out_of_range("Unknown channel: " + channel);
}

return it->second;
TinyMarsh marked this conversation as resolved.
Show resolved Hide resolved
}

std::unique_ptr<AnalysisModule> build_analysis_module(Repository &repository,
Expand Down
12 changes: 12 additions & 0 deletions src/HealthGPS/analysis_module.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ class AnalysisModule final : public UpdatableModule {
WeightModel weight_classifier_;
DoubleAgeGenderTable residual_disability_weight_;
std::vector<std::string> channels_;
std::unordered_map<std::string, int> channel_index_;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I might not be completely understanding what these data structures do... But am I right in thinking channel_index_ contains indexes into channels_? If so, could you just make channels_ a std::map<std::string, std::string> (I'm guessing it needs to be ordered...) and drop channels_index_?

unsigned int comorbidities_;
std::string name_{"Analysis"};
std::vector<core::Identifier> factors_to_calculate_ = {"Gender"_id, "Age"_id};
Expand All @@ -70,8 +71,19 @@ class AnalysisModule final : public UpdatableModule {
void calculate_population_statistics(RuntimeContext &context, DataSeries &series) const;

void classify_weight(hgps::DataSeries &series, const hgps::Person &entity) const;
void classify_weight(const Person &person);
void initialise_output_channels(RuntimeContext &context);

/// @brief Calculates the bin index in `calculated_stats_` for a given person
/// @param person The person to calculate the index for
/// @return The index in `calculated_stats_`
size_t calculate_index(const Person &person) const;

/// @brief Gets the index of the given channel name
/// @param channel The channel name
/// @return The channel index
size_t get_channel_index(const std::string &channel) const;

/// @brief Calculates the standard deviation of factors given data series containing means
/// @param context The runtime context
/// @param series The data series containing factor means
Expand Down
Loading