Skip to content

Commit

Permalink
require schema-like args
Browse files Browse the repository at this point in the history
  • Loading branch information
mwalmsley committed Mar 21, 2024
1 parent 48013c5 commit 33b612d
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 11 deletions.
3 changes: 2 additions & 1 deletion tests/pytorch/test_define_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ def schema():
def test_ZoobotTree_init(schema):
model = define_model.ZoobotTree(
output_dim=12,
question_index_groups=schema.question_index_groups,
question_answer_pairs=schema.question_answer_pairs,
dependencies=schema.dependencies
)

17 changes: 7 additions & 10 deletions zoobot/pytorch/estimators/define_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,6 @@ class ZoobotTree(GenericLightningModule):
Args:
output_dim (int): Output dimension of model's head e.g. 34 for predicting a 34-answer decision tree.
question_index_groups (List): Mapping of which label indices are part of the same question. See :ref:`training_on_vote_counts`.
architecture_name (str, optional): Architecture to use. Passed to timm. Must be in timm.list_models(). Defaults to "efficientnet_b0".
channels (int, optional): Num. input channels. Probably 3 or 1. Defaults to 1.
test_time_dropout (bool, optional): Apply dropout at test time, to pretend to be Bayesian. Defaults to True.
Expand All @@ -192,7 +191,7 @@ def __init__(
self,
output_dim: int,
# in the simplest case, this is all zoobot needs: grouping of label col indices as questions
question_index_groups: List=None,
# question_index_groups: List=None,
# BUT
# if you pass these, it enables better per-question and per-survey logging (because we have names)
# must be passed as simple dicts, not objects, so can't just pass schema in
Expand All @@ -219,7 +218,6 @@ def __init__(
super().__init__(
# these all do nothing, they are simply saved by lightning as hparams
output_dim,
question_index_groups,
question_answer_pairs,
dependencies,
architecture_name,
Expand All @@ -236,13 +234,12 @@ def __init__(

logging.info('Generic __init__ complete - moving to Zoobot __init__')

if question_answer_pairs is not None:
logging.info('question_index_groups/dependencies passed to Zoobot, constructing schema in __init__')
# assert question_index_groups is None, "Don't pass both question_index_groups and question_answer_pairs/dependencies"
assert dependencies is not None
self.schema = schemas.Schema(question_answer_pairs, dependencies)
# replace with schema-derived version
question_index_groups = self.schema.question_index_groups
# logging.info('question_index_groups/dependencies passed to Zoobot, constructing schema in __init__')
# assert question_index_groups is None, "Don't pass both question_index_groups and question_answer_pairs/dependencies"
assert dependencies is not None
self.schema = schemas.Schema(question_answer_pairs, dependencies)
# replace with schema-derived version
question_index_groups = self.schema.question_index_groups

self.setup_metrics()

Expand Down

0 comments on commit 33b612d

Please sign in to comment.