modify cli; fix model position bug
XuhuiZhou committed May 28, 2024
1 parent 9e15790 commit 4906efe
Showing 1 changed file with 5 additions and 4 deletions.
sotopia/benchmark/cli.py (9 changes: 5 additions & 4 deletions)
@@ -54,6 +54,7 @@ def check_existing_episodes(
assert isinstance(episode, EpisodeLog), "episode should be an EpisodeLog"
if episode.agents == agent_ids and episode.models == list(models.values()):
return True
+ breakpoint()
return False
else:
return False
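
The equality check above relies on list(models.values()) matching the order in which the episode's models were stored, and in Python 3.7+ a dict's value order is its insertion order; the cli hunk at the bottom of this diff reorders the dict for exactly this reason. A minimal illustration of the pitfall, assuming the stored order is [env, agent1, agent2] (the model names here are placeholders, not taken from the repo):

    # Assumed stored order: [env, agent1, agent2]
    stored_models = ["gpt-4", "my-model", "gpt-3.5-turbo"]

    wrong_order = {"agent1": "my-model", "agent2": "gpt-3.5-turbo", "env": "gpt-4"}
    right_order = {"env": "gpt-4", "agent1": "my-model", "agent2": "gpt-3.5-turbo"}

    print(list(wrong_order.values()) == stored_models)  # False: existing episode looks missing
    print(list(right_order.values()) == stored_models)  # True: existing episode is detected
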
@@ -70,7 +71,6 @@ def initilize_benchmark_combo(dataset) -> list[EnvAgentComboStorage]: # type: i


def get_avg_reward(episodes: list[EpisodeLog], model_name: str) -> dict[str, float]:
- breakpoint()
rewards_list = []
avg_reward_dict = {}
for episode in episodes:
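
Only the head of get_avg_reward appears in this hunk. As a rough sketch of the aggregation its name and dict[str, float] return type suggest (the per-episode reward layout and the rewards_for accessor are assumptions for illustration, not the repo's actual API):

    from collections import defaultdict

    def get_avg_reward_sketch(episodes, model_name):
        # Hypothetical: assume each episode can report per-dimension scores for model_name.
        totals, counts = defaultdict(float), defaultdict(int)
        for episode in episodes:
            for dim, score in episode.rewards_for(model_name).items():  # assumed accessor
                totals[dim] += score
                counts[dim] += 1
        return {dim: totals[dim] / counts[dim] for dim in totals}
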
@@ -157,6 +157,7 @@ def run_async_benchmark_in_batch(
model_names=model_names, tag=tag, env_agent_combo_storage_list=benchmark_combo
)
env_agent_combo_iter_length = sum(1 for _ in env_agent_combo_iter)
+ breakpoint()
env_agent_combo_iter = _iterate_all_env_agent_combo_not_in_db(
model_names=model_names, tag=tag, env_agent_combo_storage_list=benchmark_combo
)
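
Counting the combos with sum(1 for _ in env_agent_combo_iter) consumes the generator, which is why the same _iterate_all_env_agent_combo_not_in_db call is repeated immediately afterwards. A small sketch of that behavior with a stand-in generator (names are illustrative only):

    def make_combo_iter():  # stand-in for _iterate_all_env_agent_combo_not_in_db
        yield from ["combo_a", "combo_b", "combo_c"]

    combo_iter = make_combo_iter()
    length = sum(1 for _ in combo_iter)  # counting exhausts the generator
    print(length, list(combo_iter))      # 3 [] -- nothing left to iterate over

    combo_iter = make_combo_iter()       # recreate it before the real loop
    print(list(combo_iter))              # ['combo_a', 'combo_b', 'combo_c']
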
@@ -220,7 +221,7 @@ def run_async_benchmark_in_batch(
number_of_fix_turns += 1
if env_agent_combo_iter_length == 0 or number_of_fix_turns >= 5:
rewards_dict = get_avg_reward(simulated_episodes, model_names["agent2"]) # type: ignore
rewards_dict["model_name"] = model_names["agent2"] # type: ignore
rewards_dict["model_name"] = model_names["agent1"] # type: ignore
print(rewards_dict)
return
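
With this change the printed summary is labeled with the model under test, which the cli below maps to the "agent1" slot, rather than the partner model in "agent2". A sketch of the resulting shape (the reward dimensions and model names are placeholders; only the "model_name" key comes from the diff):

    # Hypothetical averages returned by get_avg_reward
    rewards_dict = {"believability": 8.2, "goal": 7.5}
    model_names = {"env": "gpt-4", "agent1": "my-model", "agent2": "gpt-3.5-turbo"}

    rewards_dict["model_name"] = model_names["agent1"]  # label results with the benchmarked model
    print(rewards_dict)
    # {'believability': 8.2, 'goal': 7.5, 'model_name': 'my-model'}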

@@ -265,15 +266,15 @@ def cli(
"""A simple command-line interface example."""
_set_up_logs(print_logs=print_logs)
typer.echo(
f"Running benchmark for {model} and {partner_model} on task {task} with {evaluator_model} as the evaluator."
f"Running benchmark for {model} chatting with {partner_model} on task {task} with {evaluator_model} as the evaluator."
)
model = cast(LLM_Name, model)
partner_model = cast(LLM_Name, partner_model)
evaluator_model = cast(LLM_Name, evaluator_model)
tag = f"benchmark_{model}_q"
run_async_benchmark_in_batch(
batch_size=batch_size,
model_names={"agent1": model, "agent2": partner_model, "env": evaluator_model},
model_names={"env": evaluator_model, "agent1": model, "agent2": partner_model},
tag=tag,
verbose=False,
)
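
The "model position" fix named in the commit message hinges on this dict: check_existing_episodes compares list(models.values()) positionally, so putting "env" first presumably keeps the value order aligned with how episodes record their models. A small sketch of the values this run_async_benchmark_in_batch call ends up passing (model names are placeholders):

    model, partner_model, evaluator_model = "my-model", "gpt-3.5-turbo", "gpt-4"

    tag = f"benchmark_{model}_q"
    model_names = {"env": evaluator_model, "agent1": model, "agent2": partner_model}

    print(tag)                         # benchmark_my-model_q
    print(list(model_names.values()))  # ['gpt-4', 'my-model', 'gpt-3.5-turbo']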
