diff --git a/docs/pages/concepts/database.md b/docs/pages/concepts/database.md index be1c8235f..818947224 100644 --- a/docs/pages/concepts/database.md +++ b/docs/pages/concepts/database.md @@ -1,3 +1,10 @@ +## Basic Concepts of the Database + +We use the `JsonModel` class to define the structure of the database. This class is coming from [Redis-OM](https://github.com/redis/redis-om-node), which is a Redis Object Mapper for Node.js. You should learn a lot about how to use the database by reading their documentation. + +We do have more customized method `.all()` would return a list of all the corresponding objects in the database. Note that using `.all()` is going to block the database, so you should use it with caution. +You might see codes using `all_pks()` which is a method from the `JsonModel` class. This method returns a list of all the primary keys of the corresponding objects in the database. This method is not blocking the database. However, it might not be as efficient as using `.all()`. + ## Adding new characters and environments You can use the following function with the `**kwargs` being the properties of the `AgentProfile` class. This is the same for the scenarios/environments. ```python @@ -98,10 +105,7 @@ It is very easy to serialize any database structures to JSON or CSV. ```python from sotopia.database import episodes_to_jsonl, EpisodeLog -episodes: list[EpisodeLog] = [ - EpisodeLog.get(pk=pk) - for pk in EpisodeLog.all_pks() -] +episodes: list[EpisodeLog] = EpisodeLog.all() episodes_to_jsonl(episodes, "episodes.jsonl") ``` @@ -111,15 +115,9 @@ episodes_to_jsonl(episodes, "episodes.jsonl") ```python from sotopia.database import environmentprofiles_to_jsonl, agentprofiles_to_jsonl -agent_profiles: list[AgentProfile] = [ - AgentProfile.get(pk=pk) - for pk in AgentProfile.all_pks() -] +agent_profiles: list[AgentProfile] = AgentProfile.all() -environment_profiles: list[EnvironmentProfile] = [ - EnvironmentProfile.get(pk=pk) - for pk in EnvironmentProfile.all_pks() -] +environment_profiles: list[EnvironmentProfile] = EnvironmentProfile.all() agentprofiles_to_jsonl(agent_profiles, "agent_profiles.jsonl") environmentprofiles_to_jsonl(environment_profiles, "environment_profiles.jsonl") diff --git a/examples/benchmark_evaluator.py b/examples/benchmark_evaluator.py index e07ce4efd..b35ce2247 100644 --- a/examples/benchmark_evaluator.py +++ b/examples/benchmark_evaluator.py @@ -25,8 +25,8 @@ def get_human_annotations( target_model_patterns: list[list[str]], ) -> list[AnnotationForEpisode]: episodes_with_human_annotation: list[AnnotationForEpisode] = [] - for pk in AnnotationForEpisode.all_pks(): - episode_human = AnnotationForEpisode.get(pk) + human_annotated_episodes = AnnotationForEpisode.all() + for episode_human in human_annotated_episodes: episode_model = EpisodeLog.get(episode_human.episode) if episode_model.models in target_model_patterns: episodes_with_human_annotation.append(episode_human) diff --git a/examples/evaluate_existing_episode.py b/examples/evaluate_existing_episode.py index 9a143bd99..8b80b4a43 100644 --- a/examples/evaluate_existing_episode.py +++ b/examples/evaluate_existing_episode.py @@ -98,10 +98,7 @@ def run_server( push_to_db: bool = True, verbose: bool = False, ) -> None: - annotated_episodes_pks = [ - AnnotationForEpisode.get(anno).episode - for anno in AnnotationForEpisode.all_pks() - ] + annotated_episodes_pks = [anno.episode for anno in AnnotationForEpisode.all()] annotated_episodes_pks = list(set(annotated_episodes_pks)) model = typing.cast(LLM_Name, model) # Call the function with the specified parameters diff --git a/sotopia/database/__init__.py b/sotopia/database/__init__.py index 34dd74de6..735d005ee 100644 --- a/sotopia/database/__init__.py +++ b/sotopia/database/__init__.py @@ -1,3 +1,5 @@ +from typing import TypeVar +from redis_om import JsonModel from .annotators import Annotator from .env_agent_combo_storage import EnvAgentComboStorage from .logs import AnnotationForEpisode, EpisodeLog @@ -60,3 +62,12 @@ "jsonl_to_envagnetcombostorage", "get_rewards_from_episode", ] + +InheritedJsonModel = TypeVar("InheritedJsonModel", bound="JsonModel") + + +def _json_model_all(cls: type[InheritedJsonModel]) -> list[InheritedJsonModel]: + return cls.find().all() # type: ignore[return-value] + + +JsonModel.all = classmethod(_json_model_all) # type: ignore[assignment,method-assign] diff --git a/sotopia/samplers/constraint_based_sampler.py b/sotopia/samplers/constraint_based_sampler.py index 48eaea74d..624520dcd 100644 --- a/sotopia/samplers/constraint_based_sampler.py +++ b/sotopia/samplers/constraint_based_sampler.py @@ -86,6 +86,12 @@ def sample( env_profiles: list[EnvironmentProfile] = [] agents_which_fit_scenario: list[list[str]] = [] + if self.env_candidates is None: + self.env_candidates = EnvironmentProfile.all() + + if self.agent_candidates is None: + self.agent_candidates = AgentProfile.all() + agent_candidate_ids: set[str] | None = None if self.agent_candidates: agent_candidate_ids = set( @@ -120,13 +126,9 @@ def sample( ) else: for _ in range(size): - if self.env_candidates: - env_profile = random.choice(self.env_candidates) - if isinstance(env_profile, str): - env_profile = EnvironmentProfile.get(env_profile) - else: - env_profile_id = random.choice(list(EnvironmentProfile.all_pks())) - env_profile = EnvironmentProfile.get(env_profile_id) + env_profile = random.choice(self.env_candidates) + if isinstance(env_profile, str): + env_profile = EnvironmentProfile.get(env_profile) env_profiles.append(env_profile) env_profile_id = env_profile.pk assert env_profile_id, "Env candidate must have an id" diff --git a/sotopia/samplers/uniform_sampler.py b/sotopia/samplers/uniform_sampler.py index 9e4f36116..22c89f3c3 100644 --- a/sotopia/samplers/uniform_sampler.py +++ b/sotopia/samplers/uniform_sampler.py @@ -46,32 +46,19 @@ def sample( assert replacement, "Uniform sampling without replacement is not supported yet" + if self.env_candidates is None: + self.env_candidates = EnvironmentProfile.all() + + if self.agent_candidates is None: + self.agent_candidates = AgentProfile.all() + for _ in range(size): - if self.env_candidates: - env_profile = random.choice(self.env_candidates) - if isinstance(env_profile, str): - env_profile = EnvironmentProfile.get(env_profile) - else: - env_profile_id = random.choice(list(EnvironmentProfile.all_pks())) - env_profile = EnvironmentProfile.get(env_profile_id) + env_profile = random.choice(self.env_candidates) + if isinstance(env_profile, str): + env_profile = EnvironmentProfile.get(env_profile) env = ParallelSotopiaEnv(env_profile=env_profile, **env_params) - if self.agent_candidates: - agent_profile_candidates = self.agent_candidates - if len(agent_profile_candidates) < n_agent: - raise ValueError( - f"Number of agent candidates ({len(agent_profile_candidates)}) is less than number of agents ({n_agent})" - ) - else: - agent_profile_candidates_keys = list(AgentProfile.all_pks()) - if len(agent_profile_candidates_keys) < n_agent: - raise ValueError( - f"Number of agent profile candidates ({len(agent_profile_candidates_keys)}) in database is less than number of agents ({n_agent})" - ) - agent_profile_candidates = [ - AgentProfile.get(pk=pk) for pk in agent_profile_candidates_keys - ] - + agent_profile_candidates = self.agent_candidates if len(agent_profile_candidates) == n_agent: agent_profiles_maybe_id = agent_profile_candidates else: diff --git a/stubs/redis_om/__init__.pyi b/stubs/redis_om/__init__.pyi index 230b0e48c..038ca7ac4 100644 --- a/stubs/redis_om/__init__.pyi +++ b/stubs/redis_om/__init__.pyi @@ -27,6 +27,8 @@ class JsonModel(RedisModel, abc.ABC): @classmethod def all_pks(cls) -> Generator[str, None, None]: ... @classmethod + def all(cls: type[InheritedJsonModel]) -> list[InheritedJsonModel]: ... + @classmethod def find(cls, *args: Any, **kwargs: Any) -> FindQuery: ... def save(self) -> None: ...