Skip to content

Commit

Permalink
feat(exp_eval): support tag for combo iteration (#245)
Browse files Browse the repository at this point in the history
* support tag for combo iteration

* fix pre-commit error

* [autofix.ci] apply automated fixes

* fix mypy error

---------

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
  • Loading branch information
lwaekfjlk and autofix-ci[bot] authored Nov 11, 2024
1 parent 19d39e0 commit 7d2ab2f
Showing 1 changed file with 11 additions and 2 deletions.
13 changes: 11 additions & 2 deletions examples/experiment_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,11 @@ def _iterate_env_agent_combo_not_in_db(
)
assert env_agent_combo_storage_list
first_env_agent_combo_storage_to_run: EnvAgentComboStorage | None = None

env_agent_combo_storage_list = sorted(
env_agent_combo_storage_list, key=lambda x: str(x.pk)
)

for env_agent_combo_storage in env_agent_combo_storage_list:
env_agent_combo_storage = cast(
EnvAgentComboStorage, env_agent_combo_storage
Expand Down Expand Up @@ -183,10 +188,14 @@ def run_async_server_in_batch(
logger.removeHandler(rich_handler)

# we cannot get the exact length of the generator, we just give an estimate of the length
env_agent_combo_iter = _iterate_env_agent_combo_not_in_db(model_names=model_names)
env_agent_combo_iter = _iterate_env_agent_combo_not_in_db(
model_names=model_names, tag=tag
)
env_agent_combo_iter_length = sum(1 for _ in env_agent_combo_iter)

env_agent_combo_iter = _iterate_env_agent_combo_not_in_db(model_names=model_names)
env_agent_combo_iter = _iterate_env_agent_combo_not_in_db(
model_names=model_names, tag=tag
)
env_agent_combo_batch: list[EnvAgentCombo[Observation, AgentAction]] = []

while True:
Expand Down

0 comments on commit 7d2ab2f

Please sign in to comment.