From 6583662baa5486aa379ba25d4d1e8af0001a81cf Mon Sep 17 00:00:00 2001 From: saraloo <45245630+saraloo@users.noreply.github.com> Date: Tue, 21 May 2024 09:48:01 -0400 Subject: [PATCH 1/4] fix --- flepimop/gempyor_pkg/src/gempyor/outcomes.py | 23 ++++++++++++++++---- utilities/prune_by_llik.py | 3 ++- 2 files changed, 21 insertions(+), 5 deletions(-) diff --git a/flepimop/gempyor_pkg/src/gempyor/outcomes.py b/flepimop/gempyor_pkg/src/gempyor/outcomes.py index d29d8f395..0b1d85fb9 100644 --- a/flepimop/gempyor_pkg/src/gempyor/outcomes.py +++ b/flepimop/gempyor_pkg/src/gempyor/outcomes.py @@ -251,10 +251,25 @@ def read_parameters_from_config(modinf: model_info.ModelInfo): return parameters -def postprocess_and_write(sim_id, modinf, outcomes_df, hpar, npi): - outcomes_df["time"] = outcomes_df["date"] - modinf.write_simID(ftype="hosp", sim_id=sim_id, df=outcomes_df) - modinf.write_simID(ftype="hpar", sim_id=sim_id, df=hpar) +def postprocess_and_write(sim_id, modinf, outcomes_df, hpar, npi, write=True): + if write: + # ADDED CODE + #outcomes_df = outcomes_df.set_index("date") + #reg = .8 + #mult=3 + #print("reg is", reg) + #for sp in outcomes_df["subpop"].unique(): + # max_fit = outcomes_df[outcomes_df["subpop"]==sp]["incidC"][:"2024-04-08"].max()*reg # HERE MULTIPLIED BY A REG factor: .9 + # max_summer = outcomes_df[outcomes_df["subpop"]==sp]["incidC"]["2024-04-08":"2024-09-30"].max() + # if max_summer > max_fit: + # print(f"changing {sp} because max_summer max_summer={max_summer:.1f} > reg*max_fit={max_fit:.1f}, diff {max_fit/max_summer*100:.1f}%") + # print(f">>> MULT BY {max_summer/max_fit*mult:2f}") + # outcomes_df.loc[outcomes_df["subpop"]==sp, ["incidH", "incidD"]] = outcomes_df.loc[outcomes_df["subpop"]==sp, ["incidH", "incidD"]]*max_summer/max_fit*mult + + #outcomes_df = outcomes_df.reset_index() + # END ADDED CODE; DELETE THIS PATCH + modinf.write_simID(ftype="hosp", sim_id=sim_id, df=outcomes_df) + modinf.write_simID(ftype="hpar", sim_id=sim_id, df=hpar) if npi is None: hnpi = pd.DataFrame( diff --git a/utilities/prune_by_llik.py b/utilities/prune_by_llik.py index 9b1b18929..5b1f3224b 100644 --- a/utilities/prune_by_llik.py +++ b/utilities/prune_by_llik.py @@ -143,7 +143,8 @@ def copy_path(src, dst): file_types = [ "llik", - "seed", + #"seed", + "init", "snpi", "hnpi", "spar", From c8c9d97d769083b1605cec03b2cdb14a5dbb1dab Mon Sep 17 00:00:00 2001 From: saraloo <45245630+saraloo@users.noreply.github.com> Date: Tue, 21 May 2024 18:25:50 -0400 Subject: [PATCH 2/4] test config check --- .../gempyor_pkg/src/gempyor/check_config.py | 306 ++++++++++++++++++ 1 file changed, 306 insertions(+) create mode 100644 flepimop/gempyor_pkg/src/gempyor/check_config.py diff --git a/flepimop/gempyor_pkg/src/gempyor/check_config.py b/flepimop/gempyor_pkg/src/gempyor/check_config.py new file mode 100644 index 000000000..1be3b9057 --- /dev/null +++ b/flepimop/gempyor_pkg/src/gempyor/check_config.py @@ -0,0 +1,306 @@ +import yaml +from pydantic import BaseModel, ValidationError, model_validator, Field, AfterValidator, validator +from datetime import date +from typing import Dict, List, Union, Literal, Optional, Annotated, Any +from functools import partial +from gempyor import compartments + +def read_yaml(file_path: str) -> dict: + with open(file_path, 'r') as stream: + config = yaml.safe_load(stream) + + return CheckConfig(**config).model_dump() + +def allowed_values(v, values): + assert v in values + return v + +# def parse_value(cls, values): +# value = values.get('value') +# parsed_val = 
compartments.Compartments.parse_parameter_strings_to_numpy_arrays_v2(value) +# return parsed_val + +class SubpopSetupConfig(BaseModel): + geodata: str + mobility: Optional[str] + selected: List[str] = Field(default_factory=list) + # state_level: Optional[bool] = False # pretty sure this doesn't exist anymore + +class InitialConditionsConfig(BaseModel): + method: Annotated[str, AfterValidator(partial(allowed_values, values=['Default', 'SetInitialConditions', 'SetInitialConditionsFolderDraw', 'InitialConditionsFolderDraw', 'FromFile', 'plugin']))] = 'Default' + initial_file_type: Optional[str] + initial_conditions_file: Annotated[str, AfterValidator(partial(allowed_values, values=['Default', 'SetInitialConditions', 'SetInitialConditionsFolderDraw', 'InitialConditionsFolderDraw', 'FromFile']))] = None + proportional: Optional[bool] = None + allow_missing_subpops: Optional[bool] = None + allow_missing_compartments: Optional[bool] = None + ignore_population_checks: Optional[bool] = None + plugin_file_path: Optional[str] = None + + @model_validator(mode='before') + def validate_initial_file_check(cls, values): + method = values.get('method') + initial_conditions_file = values.get('initial_conditions_file') + initial_file_type = values.get('initial_file_type') + if method in {'FromFile', 'SetInitialConditions'} and not initial_conditions_file: + raise ValueError('An initial_conditions_file is required when method is FromFile') + if method in {'InitialConditionsFolderDraw','SetInitialConditionsFolderDraw'} and not initial_file_type: + raise ValueError('initial_file_type is required when method is InitialConditionsFolderDraw') + return values + + @model_validator(mode='before') + def plugin_filecheck(cls, values): + method = values.get('method') + plugin_file_path = values.get('plugin_file_path') + if method == 'plugin' and not plugin_file_path: + raise ValueError('a plugin file path is required when method is plugin') + return values + + +class SeedingConfig(BaseModel): + method: Annotated[str, AfterValidator(partial(allowed_values, values=['NoSeeding', 'NegativeBinomialDistributed', 'PoissonDistributed', 'FolderDraw', 'FromFile', 'plugin']))] = 'NoSeeding' + plugin_file_path: Optional[str] = None + + @model_validator(mode='before') + def plugin_filecheck(cls, values): + method = values.get('method') + plugin_file_path = values.get('plugin_file_path') + if method == 'plugin' and not plugin_file_path: + raise ValueError('a plugin file path is required when method is plugin') + return values + +class IntegrationConfig(BaseModel): + method: Annotated[str, AfterValidator(partial(allowed_values, values=['rk4', 'rk4.jit', 'best.current', 'legacy']))] = 'rk4' + dt: float = 2.0 + +class ValueConfig(BaseModel): + distribution: str = 'fixed' + value: Optional[float] = None + mean: Optional[float] = None + sd: Optional[float] = None + a: Optional[float] = None + b: Optional[float] = None + # NEED TO ADD ABILITY TO PARSE PARAMETERS + + @model_validator(mode='before') + def check_distr(cls, values): + distr = values.get('distribution') + value = values.get('value') + mean = values.get('mean') + sd = values.get('sd') + a = values.get('a') + b = values.get('b') + if distr != 'fixed': + if not mean and not sd: + raise ValueError('mean and sd must be provided for non-fixed distributions') + if distr == 'truncnorm' and not a and not b: + raise ValueError('a and b must be provided for truncated normal distributions') + return values + +class BaseParameterConfig(BaseModel): + value: Optional[ValueConfig] = None + 
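# --- Illustration (not part of the diff hunks above): a minimal sketch of how the
# ValueConfig validation defined above behaves, assuming pydantic v2 and that the
# new module is importable as gempyor.check_config. The numeric values are made-up
# demo inputs, not taken from any flepiMoP config.
from pydantic import ValidationError
from gempyor.check_config import ValueConfig

ok_fixed = ValueConfig(distribution="fixed", value=0.5)                            # passes
ok_trunc = ValueConfig(distribution="truncnorm", mean=0.5, sd=0.1, a=0.0, b=1.0)   # passes
try:
    ValueConfig(distribution="truncnorm", mean=0.5, sd=0.1)                        # a and b missing
except ValidationError as err:
    print(err)  # reports: a and b must be provided for truncated normal distributions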
+class SeirParameterConfig(BaseParameterConfig): + value: Optional[ValueConfig] = None + stacked_modifier_method: Annotated[str, AfterValidator(partial(allowed_values, values=['sum', 'product', 'reduction_product']))] = None + rolling_mean_windows: Optional[float] = None + timeseries: Optional[str] = None + + @model_validator(mode='before') + def which_value(cls, values): + value = values.get('value') + timeseries = values.get('timeseries') + if value and timeseries: + raise ValueError('your parameter is both a timeseries and a value, please choose one') + return values + + +class TransitionConfig(BaseModel): + # !! sometimes these are lists of lists and sometimes they are lists... how to deal with this? + source: List[List[str]] + destination: List[List[str]] + proportion_exponent: List[List[str]] + proportional_to: List[str] + +class SeirConfig(BaseModel): + integration: IntegrationConfig # is this Optional? + parameters: Dict[str, SeirParameterConfig] + transitions: List[TransitionConfig] + +class SinglePeriodModifierConfig(BaseModel): + method: Literal["SinglePeriodModifier"] + parameter: str + period_start_date: date + period_end_date: date + subpop: str + subpop_groups: Optional[str] = None + value: ValueConfig + perturbation: Optional[ValueConfig] = None + +class MultiPeriodDatesConfig(BaseModel): + start_date: date + end_date: date + +class MultiPeriodGroupsConfig(BaseModel): + subpop: List[str] + periods: List[MultiPeriodDatesConfig] + +class MultiPeriodModifierConfig(BaseModel): + method: Literal["MultiPeriodModifier"] + parameter: str + groups: List[MultiPeriodGroupsConfig] + period_start_date: date + period_end_date: date + subpop: str + value: ValueConfig + perturbation: Optional[ValueConfig] = None + +class StackedModifierConfig(BaseModel): + method: Literal["StackedModifier"] + modifiers: List[str] + +class ModifiersConfig(BaseModel): + scenarios: List[str] + modifiers: Dict[str, Any] + + @validator("modifiers") + def validate_data_dict(cls, value: Dict[str, Any]) -> Dict[str, Any]: + errors = [] + for key, entry in value.items(): + method = entry.get("method") + if method not in {"SinglePeriodModifier", "MultiPeriodModifier", "StackedModifier"}: + errors.append(f"Invalid modifier method: {method}") + if errors: + raise ValueError("Errors in dictionary entries:\n" + "\n".join(errors)) + return value + + +class SourceConfig(BaseModel): # i think this can be incidence or prevalence, or any other source name? (this one is maybe a bit complicated to validate...) 
+ incidence: Dict[str, str] + # TO FIX + + def get_source_names(self): + source_names = [] + for key in self.incidence: + source_names.append(key) + return source_names # Access keys using a loop + # def get_source_names(self): + # return self.incidence.keys() + +class DelayFrameConfig(BaseModel): + source: Optional[SourceConfig] = None + probability: Optional[BaseParameterConfig] = None + delay: Optional[BaseParameterConfig] = None + duration: Optional[BaseParameterConfig] = None + name: Optional[str] = None + sum: Optional[List[str]] = None + + # @validator("sum") + # def validate_sum_elements(cls, value: Optional[List[str]]) -> Optional[List[str]]: + # if value is None: + # return None + # # source = value.get('source') + # source_names = {name for name in cls.source.get_source_names()} # Get source names from source config + # invalid_elements = [element for element in value if element not in source_names] + # if invalid_elements: + # raise ValueError(f"Invalid elements in 'sum': {', '.join(invalid_elements)} not found in source names") + # return value + # NOTE: ^^ this doesn't work yet because it needs to somehow be a level above? to access all OTHER source names + + @model_validator(mode='before') + def check_outcome_type(cls, values): + sum_present = values.get('sum') is not None + source_present = values.get('source') is not None + + if sum_present and source_present: + raise ValueError(f"Error in outcome: Both 'sum' and 'source' are present. Choose one.") + elif not sum_present and not source_present: + raise ValueError(f"Error in outcome: Neither 'sum' nor 'source' is present. Choose one.") + return values + + +class OutcomesConfig(BaseModel): + method: Literal["delayframe"] # Is this required? I don't see it anywhere in the gempyor code + outcomes: Dict[str, DelayFrameConfig] + +class ResampleConfig(BaseModel): + aggregator: str + freq: str + skipna: bool = False + +class LikelihoodParams(BaseModel): + scale: float + # are there other options here? + +class LikelihoodConfig(BaseModel): + dist: str + params: Optional[LikelihoodParams] = None + +class StatisticsConfig(BaseModel): + name: str + sim_var: str + data_var: str + aggregator: Optional[str] = None + period: Optional[str] = None + remove_na: Optional[bool] = None + add_one: Optional[bool] = None + # resample: Optional[ResampleConfig] = None + # zero_to_one: Optional[bool] = False # is this the same as add_one? remove_na? + likelihood: LikelihoodConfig + +class InferenceConfig(BaseModel): + method: Optional[str] = None # for now - i can only see emcee as an option here, otherwise ignored in classical - need to add these options + iterations_per_slot: Optional[int] # i think this is optional because it is also set in command line?? 
+ do_inference: bool + gt_data_path: str + statistics: Dict[str, StatisticsConfig] + # Need to determine here what is needed in classical vs other applications + +class CheckConfig(BaseModel): + name: str + setup_name: Optional[str] = None + model_output_dirname: Optional[str] = None + start_date: date + end_date: date + start_date_groundtruth: Optional[date] = None + end_date_groundtruth: Optional[date] = None + nslots: Optional[int] = 1 + subpop_setup: SubpopSetupConfig + compartments: Dict[str, List[str]] + initial_conditions: Optional[InitialConditionsConfig] = None + seeding: Optional[SeedingConfig] = None + seir: SeirConfig + seir_modifiers: Optional[ModifiersConfig] = None + outcomes: Optional[OutcomesConfig] = None + outcome_modifiers: Optional[ModifiersConfig] = None + inference: Optional[InferenceConfig] = None + +# add validator for if modifiers exist but seir/outcomes do not + +# there is an error in the one below + @model_validator(mode='before') + def verify_inference(cls, values): + inference_present = values.get('inference') is not None + start_date_groundtruth = values.get('start_date_groundtruth') is not None + if inference_present and not start_date_groundtruth: + raise ValueError('Inference mode is enabled but no groundtruth dates are provided') + elif start_date_groundtruth and not inference_present: + raise ValueError('Groundtruth dates are provided but inference mode is not enabled') + return values + + @model_validator(mode='before') + def check_dates(cls, values): + start_date = values.get('start_date') + end_date = values.get('end_date') + if start_date and end_date: + if end_date <= start_date: + raise ValueError('end_date must be greater than start_date') + return values + + @model_validator(mode='before') + def init_or_seed(cls, values): + init = values.get('initial_conditions') + seed = values.get('seeding') + if not init or seed: + raise ValueError('either initial_conditions or seeding must be provided') + return values + From 4a8e9d46a17b1f75a993bdc15f5dfd96b924247c Mon Sep 17 00:00:00 2001 From: saraloo <45245630+saraloo@users.noreply.github.com> Date: Thu, 30 May 2024 12:02:45 -0400 Subject: [PATCH 3/4] add some more checks in line with new_inference branch --- .../gempyor_pkg/src/gempyor/check_config.py | 129 +++++++++++------- 1 file changed, 83 insertions(+), 46 deletions(-) diff --git a/flepimop/gempyor_pkg/src/gempyor/check_config.py b/flepimop/gempyor_pkg/src/gempyor/check_config.py index 1be3b9057..95f6e3029 100644 --- a/flepimop/gempyor_pkg/src/gempyor/check_config.py +++ b/flepimop/gempyor_pkg/src/gempyor/check_config.py @@ -28,8 +28,8 @@ class SubpopSetupConfig(BaseModel): class InitialConditionsConfig(BaseModel): method: Annotated[str, AfterValidator(partial(allowed_values, values=['Default', 'SetInitialConditions', 'SetInitialConditionsFolderDraw', 'InitialConditionsFolderDraw', 'FromFile', 'plugin']))] = 'Default' - initial_file_type: Optional[str] - initial_conditions_file: Annotated[str, AfterValidator(partial(allowed_values, values=['Default', 'SetInitialConditions', 'SetInitialConditionsFolderDraw', 'InitialConditionsFolderDraw', 'FromFile']))] = None + initial_file_type: Optional[str] = None + initial_conditions_file: Optional[str] = None proportional: Optional[bool] = None allow_missing_subpops: Optional[bool] = None allow_missing_compartments: Optional[bool] = None @@ -42,9 +42,9 @@ def validate_initial_file_check(cls, values): initial_conditions_file = values.get('initial_conditions_file') initial_file_type = 
values.get('initial_file_type') if method in {'FromFile', 'SetInitialConditions'} and not initial_conditions_file: - raise ValueError('An initial_conditions_file is required when method is FromFile') + raise ValueError(f'Error in InitialConditions: An initial_conditions_file is required when method is {method}') if method in {'InitialConditionsFolderDraw','SetInitialConditionsFolderDraw'} and not initial_file_type: - raise ValueError('initial_file_type is required when method is InitialConditionsFolderDraw') + raise ValueError(f'Error in InitialConditions: initial_file_type is required when method is {method}') return values @model_validator(mode='before') @@ -52,20 +52,37 @@ def plugin_filecheck(cls, values): method = values.get('method') plugin_file_path = values.get('plugin_file_path') if method == 'plugin' and not plugin_file_path: - raise ValueError('a plugin file path is required when method is plugin') + raise ValueError('Error in InitialConditions: a plugin file path is required when method is plugin') return values class SeedingConfig(BaseModel): - method: Annotated[str, AfterValidator(partial(allowed_values, values=['NoSeeding', 'NegativeBinomialDistributed', 'PoissonDistributed', 'FolderDraw', 'FromFile', 'plugin']))] = 'NoSeeding' + method: Annotated[str, AfterValidator(partial(allowed_values, values=['NoSeeding', 'PoissonDistributed', 'FolderDraw', 'FromFile', 'plugin']))] = 'NoSeeding' # note: removed NegativeBinomialDistributed because no longer supported + lambda_file: Optional[str] = None + seeding_file_type: Optional[str] = None + seeding_file: Optional[str] = None plugin_file_path: Optional[str] = None + @model_validator(mode='before') + def validate_seedingfile(cls, values): + method = values.get('method') + lambda_file = values.get('lambda_file') + seeding_file_type = values.get('seeding_file_type') + seeding_file = values.get('seeding_file') + if method == 'PoissonDistributed' and not lambda_file: + raise ValueError(f'Error in Seeding: A lambda_file is required when method is {method}') + if method == 'FolderDraw' and not seeding_file_type: + raise ValueError('Error in Seeding: A seeding_file_type is required when method is FolderDraw') + if method == 'FromFile' and not seeding_file: + raise ValueError('Error in Seeding: A seeding_file is required when method is FromFile') + return values + @model_validator(mode='before') def plugin_filecheck(cls, values): method = values.get('method') plugin_file_path = values.get('plugin_file_path') if method == 'plugin' and not plugin_file_path: - raise ValueError('a plugin file path is required when method is plugin') + raise ValueError('Error in Seeding: a plugin file path is required when method is plugin') return values class IntegrationConfig(BaseModel): @@ -74,12 +91,11 @@ class IntegrationConfig(BaseModel): class ValueConfig(BaseModel): distribution: str = 'fixed' - value: Optional[float] = None + value: Optional[float] = None # NEED TO ADD ABILITY TO PARSE PARAMETERS mean: Optional[float] = None sd: Optional[float] = None a: Optional[float] = None b: Optional[float] = None - # NEED TO ADD ABILITY TO PARSE PARAMETERS @model_validator(mode='before') def check_distr(cls, values): @@ -91,13 +107,17 @@ def check_distr(cls, values): b = values.get('b') if distr != 'fixed': if not mean and not sd: - raise ValueError('mean and sd must be provided for non-fixed distributions') + raise ValueError('Error in value: mean and sd must be provided for non-fixed distributions') if distr == 'truncnorm' and not a and not b: - raise 
ValueError('a and b must be provided for truncated normal distributions') + raise ValueError('Error in value: a and b must be provided for truncated normal distributions') + if distr == 'fixed' and not value: + raise ValueError('Error in value: value must be provided for fixed distributions') return values class BaseParameterConfig(BaseModel): value: Optional[ValueConfig] = None + modifier_parameter: Optional[str] = None + name: Optional[str] = None # this is only for outcomes, to build outcome_prevalence_name (how to restrict this?) class SeirParameterConfig(BaseParameterConfig): value: Optional[ValueConfig] = None @@ -107,10 +127,10 @@ class SeirParameterConfig(BaseParameterConfig): @model_validator(mode='before') def which_value(cls, values): - value = values.get('value') - timeseries = values.get('timeseries') + value = values.get('value') is not None + timeseries = values.get('timeseries') is not None if value and timeseries: - raise ValueError('your parameter is both a timeseries and a value, please choose one') + raise ValueError('Error in seir::parameters: your parameter is both a timeseries and a value, please choose one') return values @@ -118,12 +138,13 @@ class TransitionConfig(BaseModel): # !! sometimes these are lists of lists and sometimes they are lists... how to deal with this? source: List[List[str]] destination: List[List[str]] + rate: List[List[str]] proportion_exponent: List[List[str]] proportional_to: List[str] class SeirConfig(BaseModel): integration: IntegrationConfig # is this Optional? - parameters: Dict[str, SeirParameterConfig] + parameters: Dict[str, SeirParameterConfig] # there was a previous issue that gempyor doesn't work if there are no parameters (eg if just numbers are used in the transitions) - do we want to get around this? transitions: List[TransitionConfig] class SinglePeriodModifierConfig(BaseModel): @@ -142,15 +163,13 @@ class MultiPeriodDatesConfig(BaseModel): class MultiPeriodGroupsConfig(BaseModel): subpop: List[str] + subpop_groups: Optional[str] = None periods: List[MultiPeriodDatesConfig] class MultiPeriodModifierConfig(BaseModel): method: Literal["MultiPeriodModifier"] parameter: str groups: List[MultiPeriodGroupsConfig] - period_start_date: date - period_end_date: date - subpop: str value: ValueConfig perturbation: Optional[ValueConfig] = None @@ -162,7 +181,7 @@ class ModifiersConfig(BaseModel): scenarios: List[str] modifiers: Dict[str, Any] - @validator("modifiers") + @field_validator("modifiers") def validate_data_dict(cls, value: Dict[str, Any]) -> Dict[str, Any]: errors = [] for key, entry in value.items(): @@ -170,29 +189,37 @@ def validate_data_dict(cls, value: Dict[str, Any]) -> Dict[str, Any]: if method not in {"SinglePeriodModifier", "MultiPeriodModifier", "StackedModifier"}: errors.append(f"Invalid modifier method: {method}") if errors: - raise ValueError("Errors in dictionary entries:\n" + "\n".join(errors)) + raise ValueError("Errors in modifiers:\n" + "\n".join(errors)) return value -class SourceConfig(BaseModel): # i think this can be incidence or prevalence, or any other source name? (this one is maybe a bit complicated to validate...) - incidence: Dict[str, str] - # TO FIX +class SourceConfig(BaseModel): # set up only for incidence or prevalence. Can this be any name? i don't think so atm + incidence: Dict[str, str] = None + prevalence: Dict[str, str] = None + # note: these dictionaries have to have compartment names... 
more complicated to set this up + + @model_validator(mode='before') + def which_source(cls, values): + incidence = values.get('incidence') + prevalence = values.get('prevalence') + if incidence and prevalence: + raise ValueError('Error in outcomes::source. Can only be incidence or prevalence, not both.') + return values - def get_source_names(self): - source_names = [] - for key in self.incidence: - source_names.append(key) - return source_names # Access keys using a loop + # @model_validator(mode='before') # DOES NOT WORK # def get_source_names(self): - # return self.incidence.keys() + # source_names = [] + # type = self.incidence or self.prevalence + # for key in type: + # source_names.append(key) + # return source_names # Access keys using a loop class DelayFrameConfig(BaseModel): source: Optional[SourceConfig] = None probability: Optional[BaseParameterConfig] = None delay: Optional[BaseParameterConfig] = None duration: Optional[BaseParameterConfig] = None - name: Optional[str] = None - sum: Optional[List[str]] = None + sum: Optional[List[str]] = None # only for sums of other outcomes # @validator("sum") # def validate_sum_elements(cls, value: Optional[List[str]]) -> Optional[List[str]]: @@ -204,7 +231,7 @@ class DelayFrameConfig(BaseModel): # if invalid_elements: # raise ValueError(f"Invalid elements in 'sum': {', '.join(invalid_elements)} not found in source names") # return value - # NOTE: ^^ this doesn't work yet because it needs to somehow be a level above? to access all OTHER source names + # note: ^^ this doesn't work yet because it needs to somehow be a level above? to access all OTHER source names @model_validator(mode='before') def check_outcome_type(cls, values): @@ -216,44 +243,54 @@ def check_outcome_type(cls, values): elif not sum_present and not source_present: raise ValueError(f"Error in outcome: Neither 'sum' nor 'source' is present. Choose one.") return values - class OutcomesConfig(BaseModel): method: Literal["delayframe"] # Is this required? I don't see it anywhere in the gempyor code + param_from_file: Optional[bool] = None + param_subpop_file: Optional[str] = None outcomes: Dict[str, DelayFrameConfig] + + @model_validator(mode='before') + def check_paramfromfile_type(cls, values): + param_from_file = values.get('param_from_file') is not None + param_subpop_file = values.get('param_subpop_file') is not None + + if param_from_file and not param_subpop_file: + raise ValueError(f"Error in outcome: 'param_subpop_file' is required when 'param_from_file' is True") + return values class ResampleConfig(BaseModel): - aggregator: str - freq: str - skipna: bool = False + aggregator: Optional[str] = None + freq: Optional[str] = None + skipna: Optional[bool] = False class LikelihoodParams(BaseModel): scale: float # are there other options here? +class LikelihoodReg(BaseModel): + name: str + class LikelihoodConfig(BaseModel): - dist: str + dist: Annotated[str, AfterValidator(partial(allowed_values, values=['pois', 'norm', 'norm_cov', 'nbinom', 'rmse', 'absolute_error']))] = None params: Optional[LikelihoodParams] = None class StatisticsConfig(BaseModel): name: str sim_var: str data_var: str - aggregator: Optional[str] = None - period: Optional[str] = None - remove_na: Optional[bool] = None - add_one: Optional[bool] = None - # resample: Optional[ResampleConfig] = None - # zero_to_one: Optional[bool] = False # is this the same as add_one? remove_na? 
+ regularize: Optional[LikelihoodReg] = None + resample: Optional[ResampleConfig] = None + scale: Optional[float] = None # is scale here or at likelihood level? + zero_to_one: Optional[bool] = False # is this the same as add_one? remove_na? likelihood: LikelihoodConfig class InferenceConfig(BaseModel): - method: Optional[str] = None # for now - i can only see emcee as an option here, otherwise ignored in classical - need to add these options + method: Annotated[str, AfterValidator(partial(allowed_values, values=['emcee', 'default', 'classical']))] = 'default' # for now - i can only see emcee as an option here, otherwise ignored in classical - need to add these options iterations_per_slot: Optional[int] # i think this is optional because it is also set in command line?? do_inference: bool gt_data_path: str statistics: Dict[str, StatisticsConfig] - # Need to determine here what is needed in classical vs other applications class CheckConfig(BaseModel): name: str @@ -264,6 +301,7 @@ class CheckConfig(BaseModel): start_date_groundtruth: Optional[date] = None end_date_groundtruth: Optional[date] = None nslots: Optional[int] = 1 + subpop_setup: SubpopSetupConfig compartments: Dict[str, List[str]] initial_conditions: Optional[InitialConditionsConfig] = None @@ -303,4 +341,3 @@ def init_or_seed(cls, values): if not init or seed: raise ValueError('either initial_conditions or seeding must be provided') return values - From 567cffa8abc4f704112a988d5c4974e39b0644bd Mon Sep 17 00:00:00 2001 From: saraloo <45245630+saraloo@users.noreply.github.com> Date: Tue, 4 Jun 2024 08:39:45 -0400 Subject: [PATCH 4/4] change filename --- .../src/gempyor/{check_config.py => config_validator.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename flepimop/gempyor_pkg/src/gempyor/{check_config.py => config_validator.py} (100%) diff --git a/flepimop/gempyor_pkg/src/gempyor/check_config.py b/flepimop/gempyor_pkg/src/gempyor/config_validator.py similarity index 100% rename from flepimop/gempyor_pkg/src/gempyor/check_config.py rename to flepimop/gempyor_pkg/src/gempyor/config_validator.py
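For reference, a minimal usage sketch of the validator introduced in these patches, assuming pydantic v2 and the post-rename module path from PATCH 4/4; the config filename below is a hypothetical placeholder, not a file in the repository:

from pydantic import ValidationError
from gempyor.config_validator import read_yaml  # module renamed from check_config.py in PATCH 4/4

try:
    # read_yaml() loads the YAML file and returns CheckConfig(**config).model_dump()
    # once every validator passes.
    validated = read_yaml("config.yml")  # hypothetical path
except ValidationError as err:
    # pydantic surfaces the failed checks, e.g. a missing initial_conditions_file
    # when method is FromFile, or an end_date that is not after start_date.
    print(err)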