From 33b612d9ba535818cc0afbb2be5a3b2dd210b048 Mon Sep 17 00:00:00 2001 From: Mike Walmsley Date: Thu, 21 Mar 2024 14:52:21 -0400 Subject: [PATCH] require schema-like args --- tests/pytorch/test_define_model.py | 3 ++- zoobot/pytorch/estimators/define_model.py | 17 +++++++---------- 2 files changed, 9 insertions(+), 11 deletions(-) diff --git a/tests/pytorch/test_define_model.py b/tests/pytorch/test_define_model.py index 3805777d..f7628d22 100644 --- a/tests/pytorch/test_define_model.py +++ b/tests/pytorch/test_define_model.py @@ -10,6 +10,7 @@ def schema(): def test_ZoobotTree_init(schema): model = define_model.ZoobotTree( output_dim=12, - question_index_groups=schema.question_index_groups, + question_answer_pairs=schema.question_answer_pairs, + dependencies=schema.dependencies ) diff --git a/zoobot/pytorch/estimators/define_model.py b/zoobot/pytorch/estimators/define_model.py index d04ab746..8670ba0e 100755 --- a/zoobot/pytorch/estimators/define_model.py +++ b/zoobot/pytorch/estimators/define_model.py @@ -177,7 +177,6 @@ class ZoobotTree(GenericLightningModule): Args: output_dim (int): Output dimension of model's head e.g. 34 for predicting a 34-answer decision tree. - question_index_groups (List): Mapping of which label indices are part of the same question. See :ref:`training_on_vote_counts`. architecture_name (str, optional): Architecture to use. Passed to timm. Must be in timm.list_models(). Defaults to "efficientnet_b0". channels (int, optional): Num. input channels. Probably 3 or 1. Defaults to 1. test_time_dropout (bool, optional): Apply dropout at test time, to pretend to be Bayesian. Defaults to True. @@ -192,7 +191,7 @@ def __init__( self, output_dim: int, # in the simplest case, this is all zoobot needs: grouping of label col indices as questions - question_index_groups: List=None, + # question_index_groups: List=None, # BUT # if you pass these, it enables better per-question and per-survey logging (because we have names) # must be passed as simple dicts, not objects, so can't just pass schema in @@ -219,7 +218,6 @@ def __init__( super().__init__( # these all do nothing, they are simply saved by lightning as hparams output_dim, - question_index_groups, question_answer_pairs, dependencies, architecture_name, @@ -236,13 +234,12 @@ def __init__( logging.info('Generic __init__ complete - moving to Zoobot __init__') - if question_answer_pairs is not None: - logging.info('question_index_groups/dependencies passed to Zoobot, constructing schema in __init__') - # assert question_index_groups is None, "Don't pass both question_index_groups and question_answer_pairs/dependencies" - assert dependencies is not None - self.schema = schemas.Schema(question_answer_pairs, dependencies) - # replace with schema-derived version - question_index_groups = self.schema.question_index_groups + # logging.info('question_index_groups/dependencies passed to Zoobot, constructing schema in __init__') + # assert question_index_groups is None, "Don't pass both question_index_groups and question_answer_pairs/dependencies" + assert dependencies is not None + self.schema = schemas.Schema(question_answer_pairs, dependencies) + # replace with schema-derived version + question_index_groups = self.schema.question_index_groups self.setup_metrics()