From 4906efe848c1baa7a7db048c4ed2884962c0edae Mon Sep 17 00:00:00 2001 From: XuhuiZhou Date: Tue, 28 May 2024 14:30:38 -0700 Subject: [PATCH] modify cli; fix model position bug --- sotopia/benchmark/cli.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/sotopia/benchmark/cli.py b/sotopia/benchmark/cli.py index ec598ec2e..fc768e2d1 100644 --- a/sotopia/benchmark/cli.py +++ b/sotopia/benchmark/cli.py @@ -54,6 +54,7 @@ def check_existing_episodes( assert isinstance(episode, EpisodeLog), "episode should be an EpisodeLog" if episode.agents == agent_ids and episode.models == list(models.values()): return True + breakpoint() return False else: return False @@ -70,7 +71,6 @@ def initilize_benchmark_combo(dataset) -> list[EnvAgentComboStorage]: # type: i def get_avg_reward(episodes: list[EpisodeLog], model_name: str) -> dict[str, float]: - breakpoint() rewards_list = [] avg_reward_dict = {} for episode in episodes: @@ -157,6 +157,7 @@ def run_async_benchmark_in_batch( model_names=model_names, tag=tag, env_agent_combo_storage_list=benchmark_combo ) env_agent_combo_iter_length = sum(1 for _ in env_agent_combo_iter) + breakpoint() env_agent_combo_iter = _iterate_all_env_agent_combo_not_in_db( model_names=model_names, tag=tag, env_agent_combo_storage_list=benchmark_combo ) @@ -220,7 +221,7 @@ def run_async_benchmark_in_batch( number_of_fix_turns += 1 if env_agent_combo_iter_length == 0 or number_of_fix_turns >= 5: rewards_dict = get_avg_reward(simulated_episodes, model_names["agent2"]) # type: ignore - rewards_dict["model_name"] = model_names["agent2"] # type: ignore + rewards_dict["model_name"] = model_names["agent1"] # type: ignore print(rewards_dict) return @@ -265,7 +266,7 @@ def cli( """A simple command-line interface example.""" _set_up_logs(print_logs=print_logs) typer.echo( - f"Running benchmark for {model} and {partner_model} on task {task} with {evaluator_model} as the evaluator." + f"Running benchmark for {model} chatting with {partner_model} on task {task} with {evaluator_model} as the evaluator." ) model = cast(LLM_Name, model) partner_model = cast(LLM_Name, partner_model) @@ -273,7 +274,7 @@ def cli( tag = f"benchmark_{model}_q" run_async_benchmark_in_batch( batch_size=batch_size, - model_names={"agent1": model, "agent2": partner_model, "env": evaluator_model}, + model_names={"env": evaluator_model, "agent1": model, "agent2": partner_model}, tag=tag, verbose=False, )