register_model is now async aware
simonw committed Nov 6, 2024
1 parent b27b275 commit 44e6be1
Showing 4 changed files with 85 additions and 28 deletions.
56 changes: 46 additions & 10 deletions llm/__init__.py
@@ -76,11 +76,11 @@ def get_models_with_aliases() -> List["ModelWithAliases"]:
     for alias, model_id in configured_aliases.items():
         extra_model_aliases.setdefault(model_id, []).append(alias)

-    def register(model, aliases=None):
+    def register(model, async_model=None, aliases=None):
         alias_list = list(aliases or [])
         if model.model_id in extra_model_aliases:
             alias_list.extend(extra_model_aliases[model.model_id])
-        model_aliases.append(ModelWithAliases(model, alias_list))
+        model_aliases.append(ModelWithAliases(model, async_model, alias_list))

     pm.hook.register_models(register=register)

@@ -136,30 +136,66 @@ def get_embedding_model_aliases() -> Dict[str, EmbeddingModel]:
     return model_aliases


+def get_async_model_aliases() -> Dict[str, AsyncModel]:
+    async_model_aliases = {}
+    for model_with_aliases in get_models_with_aliases():
+        if model_with_aliases.async_model:
+            for alias in model_with_aliases.aliases:
+                async_model_aliases[alias] = model_with_aliases.async_model
+            async_model_aliases[model_with_aliases.model.model_id] = (
+                model_with_aliases.async_model
+            )
+    return async_model_aliases
+
+
 def get_model_aliases() -> Dict[str, Model]:
     model_aliases = {}
     for model_with_aliases in get_models_with_aliases():
-        for alias in model_with_aliases.aliases:
-            model_aliases[alias] = model_with_aliases.model
-        model_aliases[model_with_aliases.model.model_id] = model_with_aliases.model
+        if model_with_aliases.model:
+            for alias in model_with_aliases.aliases:
+                model_aliases[alias] = model_with_aliases.model
+            model_aliases[model_with_aliases.model.model_id] = model_with_aliases.model
     return model_aliases


-def get_async_model(model_id: str) -> AsyncModel:
-    return get_model(model_id).get_async_model()
-
-
 class UnknownModelError(KeyError):
     pass


+def get_async_model(name: Optional[str] = None) -> AsyncModel:
+    aliases = get_async_model_aliases()
+    name = name or get_default_model()
+    try:
+        return aliases[name]
+    except KeyError:
+        # Does a sync model exist?
+        sync_model = None
+        try:
+            sync_model = get_model(name)
+        except UnknownModelError:
+            pass
+        if sync_model:
+            raise UnknownModelError("Unknown async model (sync model exists): " + name)
+        else:
+            raise UnknownModelError("Unknown model: " + name)
+
+
 def get_model(name: Optional[str] = None) -> Model:
     aliases = get_model_aliases()
     name = name or get_default_model()
     try:
         return aliases[name]
     except KeyError:
-        raise UnknownModelError("Unknown model: " + name)
+        # Does an async model exist?
+        async_model = None
+        try:
+            async_model = get_async_model(name)
+        except UnknownModelError:
+            pass
+        if async_model:
+            raise UnknownModelError("Unknown model (async model exists): " + name)
+        else:
+            raise UnknownModelError("Unknown model: " + name)


 def get_key(
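
With the two lookup tables in place, a miss on one side now reports whether the name exists on the other. A minimal usage sketch (the model IDs below are illustrative):

    import llm

    # Resolve an async model by ID or alias; with no argument this
    # falls back to the configured default model.
    model = llm.get_async_model("gpt-4o")

    # A name registered only as a sync model raises UnknownModelError
    # with a hint that the sync flavor exists.
    try:
        llm.get_async_model("sync-only-model")  # illustrative ID
    except llm.UnknownModelError as ex:
        print(ex)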
4 changes: 2 additions & 2 deletions llm/cli.py
@@ -333,8 +333,8 @@ def read_prompt():
             model = get_async_model(model_id)
         else:
             model = get_model(model_id)
-    except KeyError:
-        raise click.ClickException("'{}' is not a known model".format(model_id))
+    except UnknownModelError as ex:
+        raise click.ClickException(ex)

     # Provide the API key, if one is needed and has been provided
     if model.needs_key:
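
The CLI now surfaces the richer error message directly instead of composing its own. A rough standalone sketch of the same pattern (resolve() is a hypothetical helper, not part of the codebase):

    import click
    import llm

    def resolve(model_id, async_=False):
        try:
            if async_:
                return llm.get_async_model(model_id)
            return llm.get_model(model_id)
        except llm.UnknownModelError as ex:
            # ClickException prints the message and exits non-zero
            raise click.ClickException(str(ex))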
46 changes: 34 additions & 12 deletions llm/default_plugins/openai_models.py
@@ -23,21 +23,43 @@

 @hookimpl
 def register_models(register):
-    register(Chat("gpt-3.5-turbo"), aliases=("3.5", "chatgpt"))
-    register(Chat("gpt-3.5-turbo-16k"), aliases=("chatgpt-16k", "3.5-16k"))
-    register(Chat("gpt-4"), aliases=("4", "gpt4"))
-    register(Chat("gpt-4-32k"), aliases=("4-32k",))
+    register(
+        Chat("gpt-3.5-turbo"), AsyncChat("gpt-3.5-turbo"), aliases=("3.5", "chatgpt")
+    )
+    register(
+        Chat("gpt-3.5-turbo-16k"),
+        AsyncChat("gpt-3.5-turbo-16k"),
+        aliases=("chatgpt-16k", "3.5-16k"),
+    )
+    register(Chat("gpt-4"), AsyncChat("gpt-4"), aliases=("4", "gpt4"))
+    register(Chat("gpt-4-32k"), AsyncChat("gpt-4-32k"), aliases=("4-32k",))
     # GPT-4 Turbo models
-    register(Chat("gpt-4-1106-preview"))
-    register(Chat("gpt-4-0125-preview"))
-    register(Chat("gpt-4-turbo-2024-04-09"))
-    register(Chat("gpt-4-turbo"), aliases=("gpt-4-turbo-preview", "4-turbo", "4t"))
+    register(Chat("gpt-4-1106-preview"), AsyncChat("gpt-4-1106-preview"))
+    register(Chat("gpt-4-0125-preview"), AsyncChat("gpt-4-0125-preview"))
+    register(Chat("gpt-4-turbo-2024-04-09"), AsyncChat("gpt-4-turbo-2024-04-09"))
+    register(
+        Chat("gpt-4-turbo"),
+        AsyncChat("gpt-4-turbo"),
+        aliases=("gpt-4-turbo-preview", "4-turbo", "4t"),
+    )
     # GPT-4o
-    register(Chat("gpt-4o", vision=True), aliases=("4o",))
-    register(Chat("gpt-4o-mini", vision=True), aliases=("4o-mini",))
+    register(
+        Chat("gpt-4o", vision=True), AsyncChat("gpt-4o", vision=True), aliases=("4o",)
+    )
+    register(
+        Chat("gpt-4o-mini", vision=True),
+        AsyncChat("gpt-4o", vision=True),
+        aliases=("4o-mini",),
+    )
     # o1
-    register(Chat("o1-preview", can_stream=False, allows_system_prompt=False))
-    register(Chat("o1-mini", can_stream=False, allows_system_prompt=False))
+    register(
+        Chat("o1-preview", can_stream=False, allows_system_prompt=False),
+        AsyncChat("o1-preview", can_stream=False, allows_system_prompt=False),
+    )
+    register(
+        Chat("o1-mini", can_stream=False, allows_system_prompt=False),
+        AsyncChat("o1-mini", can_stream=False, allows_system_prompt=False),
+    )
     # The -instruct completion model
     register(
         Completion("gpt-3.5-turbo-instruct", default_max_tokens=256),
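
Third-party plugins can adopt the same shape: the async counterpart goes in as the optional second positional argument to register(). A minimal sketch, where MyModel and MyAsyncModel stand in for hypothetical Model and AsyncModel subclasses:

    import llm

    @llm.hookimpl
    def register_models(register):
        # async_model is optional; sync-only registration still works
        register(MyModel("my-model"), MyAsyncModel("my-model"), aliases=("mm",))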
7 changes: 3 additions & 4 deletions llm/models.py
@@ -387,12 +387,10 @@ def from_row(cls, db, row):
         return response

     def __repr__(self):
-        text = '... not yet awaited ...'
+        text = "... not yet awaited ..."
         if self._done:
             text = "".join(self._chunks)
-        return "<Response prompt='{}' text='{}'>".format(
-            self.prompt.prompt, text
-        )
+        return "<Response prompt='{}' text='{}'>".format(self.prompt.prompt, text)


 class Options(BaseModel):

@@ -695,6 +693,7 @@ def embed_batch(self, items: Iterable[Union[str, bytes]]) -> Iterator[List[float]]:
 @dataclass
 class ModelWithAliases:
     model: Model
+    async_model: AsyncModel
     aliases: Set[str]
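
A short sketch of consuming the extended dataclass, assuming async_model is None for plugins that register only a sync model:

    import llm

    for mwa in llm.get_models_with_aliases():
        kind = "sync + async" if mwa.async_model else "sync only"
        print(mwa.model.model_id, kind)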
