Merge pull request #211 from eth-sri/gpt-3.5-turbo-instruct
Add support for new gpt-3.5-turbo-instruct model
lbeurerkellner authored Sep 19, 2023
2 parents 5a9af64 + 624b6d6 commit 4b05aa7
Showing 27 changed files with 349 additions and 172 deletions.
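
For context, the new model can then be referenced like any other OpenAI model. A minimal sketch (illustrative; the query mirrors the `hello` example in `cli.py` below):

```python
import lmql

# gpt-3.5-turbo-instruct is a completion-style endpoint, so unlike the
# chat models it supports standard decoding with token-level constraints
@lmql.query(model=lmql.model("openai/gpt-3.5-turbo-instruct"))
def greet():
    '''lmql
    "Hello[WHO]" where len(TOKENS(WHO)) < 10 and not "\\n" in WHO
    '''

print(greet())
```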
6 changes: 3 additions & 3 deletions src/lmql/__init__.py
@@ -28,7 +28,7 @@
from lmql.runtime.output_writer import headless, printing, silent, stream
from lmql.runtime.interpreter import LMQLResult

-from lmql.models.model import model, LMQLModel
+from lmql.models.model import model, LMQLModelDescriptor
from lmql.runtime.loop import main
import lmql.runtime.decorators as decorators

@@ -44,9 +44,9 @@ def autoconnect():
def set_backend(backend):
model_registry.backend_configuration = backend

-def set_default_model(model: Union[str, LMQLModel]):
+def set_default_model(model: Union[str, LMQLModelDescriptor]):
"""
-Sets the model instance to be used when no 'from' clause or @lmql.query(model=<model>) are used.
+Sets the model instance to be used when no 'from' clause or @lmql.query(model=<model>) are specified.
This applies globally in the current process.
"""
2 changes: 0 additions & 2 deletions src/lmql/cli.py
@@ -207,8 +207,6 @@ def hello():
asyncio.run(lmql.run(code_local, output_writer=lmql.printing))

if backend is None or backend == "openai":
-import lmql.runtime.dclib as dc
-dc.clear_tokenizer()
print("[Greeting OpenAI]")
code_openai = 'argmax "Hello[WHO]" from "openai/text-ada-001" where len(TOKENS(WHO)) < 10 and not "\\n" in WHO'
asyncio.run(lmql.run(code_openai, output_writer=lmql.printing, model="openai/text-ada-001"))
9 changes: 6 additions & 3 deletions src/lmql/lib/actions.py
@@ -1,9 +1,11 @@
import lmql
import ast

from .prompts.wiki_prompt import EXAMPLES as WIKI_EXAMPLES
from .prompts.inline_use import INLINE_USE_PROMPT
from .prompts.inline_code import INLINE_CODE_PROMPT

import warnings

from dataclasses import dataclass

async def wiki(q: str, lookup = None):
@@ -95,6 +97,8 @@ async def fct_call(fcts):
print("unknown action", [action], list(action_fcts.keys()))
" Unknown action: {action}{DELIMITER_END}"
result = ""
+if CALL.lstrip().startswith("<<"):
+    warnings.warn("Detected a rare failure case where LMQL Actions function calling does not work as intended. Please share your query with the LMQL developers, so we can fix this.", RuntimeWarning)
return "(error)"
else:
try:
@@ -113,8 +117,7 @@ async def inline_segment(fcts):
if not SEGMENT.endswith(DELIMITER):
return SEGMENT
else:
"[CALL]" where fct_call(CALL, fcts) and len(TOKENS(CALL)) > 0
result = CALL.split("|", 1)[1]
"[CALL: fct_call(fcts)]" where len(TOKENS(CALL)) > 0
return SEGMENT[:-len(DELIMITER)] + CALL + DELIMITER_END
'''
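
The change above replaces the standalone `where fct_call(CALL, fcts)` constraint with LMQL's nested-query syntax, in which `[CALL: fct_call(fcts)]` delegates generation of `CALL` to the `fct_call` query. A rough sketch of the general pattern (illustrative names; see the LMQL documentation on nested queries for exact semantics):

```python
import lmql

@lmql.query
async def tagline(topic):
    '''lmql
    "A short tagline about {topic}:[TAG]" where len(TOKENS(TAG)) < 10
    return TAG
    '''

@lmql.query
async def page():
    '''lmql
    topic = "coffee"
    # TAGLINE is produced by running the nested tagline(...) query
    "Headline:[TAGLINE: tagline(topic)]"
    '''
```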

4 changes: 2 additions & 2 deletions src/lmql/lib/types.py
@@ -81,7 +81,7 @@ def extract_json(s):
@lmql.query
async def single_shot_as_type(s, ty, model="chatgpt"):
'''lmql
-argmax(openai_chunksize=1024)
+argmax(chunksize=1024)
schema_description = type_schema_description(ty)
"Provided a data schema of the following schema: {schema_description}\n"
"Translate the following into a JSON payload: {s}\n"
@@ -147,7 +147,7 @@ async def is_type(ty, description=False):
"{existing_value}{line_end}"
else:
# Chat API models do not support advanced integer constraints
if "turbo" in context.interpreter.model_identifier or "gpt-4" in context.interpreter.model_identifier:
if context.interpreter.model_identifier.endswith("-turbo") or "gpt-4" in context.interpreter.model_identifier:
"[INT_VALUE]" where STOPS_AT(INT_VALUE, ",") and len(TOKENS(INT_VALUE)) < 4
if line_end.startswith(",") and not INT_VALUE.endswith(","):
","
6 changes: 5 additions & 1 deletion src/lmql/models/lmtp/backends/random_model.py
@@ -10,10 +10,14 @@ def __init__(self, seed=None, vocab=None, **kwargs):
self.seed = seed
self.kwargs = kwargs

+if kwargs.get("verbose", False):
+    print("['random' model using seed {}]".format(seed))

if vocab is not None:
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(vocab)
print("['random' model using tokenizer {}]".format(tokenizer))
if kwargs.get("verbose", False):
print("['random' model using tokenizer {}]".format(tokenizer))
self._eos_token_id = tokenizer.eos_token_id
self._vocab_size = tokenizer.vocab_size
else:
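
With this change the 'random' backend only prints its setup messages when asked to. A small sketch of opting back in (the `seed` example follows the `lmql.model` docstring below; forwarding `vocab` and `verbose` through the descriptor kwargs is an assumption):

```python
import lmql

# verbose=True re-enables the "['random' model using ...]" log lines
m = lmql.model("random", seed=123, vocab="gpt2", verbose=True)
```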
9 changes: 7 additions & 2 deletions src/lmql/models/lmtp/backends/transformers_model.py
@@ -3,6 +3,11 @@
from lmql.models.lmtp.backends.lmtp_model import LMTPModel, LMTPModelResult, TokenStreamer
import numpy as np

def format_call(model_name, **kwargs):
if len(kwargs) == 0:
return f'"{model_name}"'
return f'"{model_name}", {", ".join([f"{k}={v}" for k, v in kwargs.items()])}'

class TransformersLLM(LMTPModel):
def __init__(self, model_identifier, **kwargs):
self.model_identifier = model_identifier
@@ -12,11 +17,11 @@ def __init__(self, model_identifier, **kwargs):

if self.model_args.pop("loader", None) == "auto-gptq":
from auto_gptq import AutoGPTQForCausalLM
print("[Loading", self.model_identifier, "with", f"AutoGPTQForCausalLM.from_quantized({self.model_identifier}, {str(self.model_args)[1:-1]})]", flush=True)
print("[Loading", self.model_identifier, "with", "AutoGPTQForCausalLM.from_quantized({})]".format(format_call(self.model_identifier, **self.model_args)), flush=True)
self.model = AutoGPTQForCausalLM.from_quantized(self.model_identifier, **self.model_args)
else:
from transformers import AutoModelForCausalLM
print("[Loading", self.model_identifier, "with", f"AutoModelForCausalLM.from_pretrained({self.model_identifier}, {str(self.model_args)[1:-1]})]", flush=True)
print("[Loading", self.model_identifier, "with", "AutoModelForCausalLM.from_pretrained({})]".format(format_call(self.model_identifier, **self.model_args)), flush=True)
self.model = AutoModelForCausalLM.from_pretrained(self.model_identifier, **self.model_args)

print("[", self.model_identifier, " ready on device ", self.model.device,
14 changes: 5 additions & 9 deletions src/lmql/models/lmtp/lmtp_dcmodel.py
@@ -12,8 +12,9 @@
import lmql.utils.nputil as nputil
import lmql.runtime.masks as masks
from lmql.runtime.token_distribution import TokenDistribution
from lmql.models.model import LMQLModel

-from typing import Any, List, Union
+from typing import Any, List, Union, Type
import random
import sys
import traceback
@@ -467,11 +468,11 @@ def __del__(self):
if "no current event loop" in str(e): pass
else: raise e

-def __call__(self):
+def __call__(self) -> LMQLModel:
# reference to factory instance
this = self

-class LMTPDcModelCls:
+class LMTPAdapterModel(LMQLModel):
def __init__(self) -> None:
self.model_identifier = this.model_identifier
self.served_model = None
@@ -491,8 +492,6 @@ def get_dclib_model(self):
bos_token_id = self.get_tokenizer().bos_token_id
eos_token_id = self.get_tokenizer().eos_token_id

-dc.set_dclib_tokenizer(self.get_tokenizer())

inprocess_client_constructor = None

if this.inprocess:
@@ -512,8 +511,5 @@ async def detokenize(self, input_ids):

def sync_tokenize(self, text):
return self.get_tokenizer()(text)["input_ids"]

-def report_metrics(self, metrics):
-    pass

-return LMTPDcModelCls()
+return LMTPAdapterModel()
93 changes: 75 additions & 18 deletions src/lmql/models/model.py
@@ -1,12 +1,19 @@
from lmql.runtime.tokenizer import LMQLTokenizer
from lmql.runtime.dclib.dclib_model import DcModel
from typing import Any
from abc import ABC, abstractmethod
from lmql.models.lmtp.utils import rename_model_args

import warnings

-class LMQLModel:
+class LMQLModelDescriptor:
"""
Descriptor base class for results of lmql.model(...) calls.
The actual model runs either remotely (lmql serve-model), in-process via self.model or
-via an API (e.g. OpenAI).
+is accessed via API (e.g. OpenAI).
+Use LMQLModelRegistry.get to resolve a descriptor into a usable LMQLModel.
"""

def __init__(self, model_identifier, model=None, **kwargs):
@@ -20,13 +27,57 @@ def __repr__(self) -> str:
return str(self)

def __str__(self):
return "<LMQLModel: {}>".format(self.model_identifier)
return "<LMQLModelDescriptor: {}>".format(self.model_identifier)

-LMQLModel.inprocess_instances = {}
+LMQLModelDescriptor.inprocess_instances = {}

class LMQLModel(ABC):
"""
Abstract base class for models (interface to integrate LMTP or OpenAI API)
"""

@abstractmethod
def get_tokenizer(self) -> LMQLTokenizer:
"""
Returns the tokenizer used by this model.
"""
raise NotImplementedError()

@abstractmethod
def get_dclib_model(self) -> DcModel:
"""
Returns the dclib DcModel handle to use this model.
This handle is used for token generation and decoding via
its methods like `argmax`, `sample` and `score`.
"""
raise NotImplementedError()

@abstractmethod
async def tokenize(self, text: str) -> Any:
"""
Tokenizes the given text and returns the tokenized input_ids in
the format expected by the model.
"""
raise NotImplementedError()

@abstractmethod
async def detokenize(self, input_ids: Any) -> str:
"""
Detokenizes the given input_ids and returns the detokenized text.
"""
raise NotImplementedError()

@abstractmethod
def sync_tokenize(self, text):
"""
Synchronous version of `tokenize`.
"""
raise NotImplementedError()
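
To make the contract concrete, a minimal hypothetical implementation could look as follows; the real in-tree example is the `LMTPAdapterModel` in `lmtp_dcmodel.py` above (the tokenizer calls follow its `sync_tokenize`; the `decode` call is an assumption):

```python
from typing import Any

from lmql.models.model import LMQLModel
from lmql.runtime.tokenizer import LMQLTokenizer
from lmql.runtime.dclib.dclib_model import DcModel

class MyBackendModel(LMQLModel):
    # illustrative stub wiring a tokenizer and a DcModel into the interface
    def __init__(self, tokenizer: LMQLTokenizer, dc_model: DcModel):
        self._tokenizer = tokenizer
        self._dc_model = dc_model

    def get_tokenizer(self) -> LMQLTokenizer:
        return self._tokenizer

    def get_dclib_model(self) -> DcModel:
        return self._dc_model

    async def tokenize(self, text: str) -> Any:
        return self._tokenizer(text)["input_ids"]

    async def detokenize(self, input_ids: Any) -> str:
        return self._tokenizer.decode(input_ids)

    def sync_tokenize(self, text):
        return self._tokenizer(text)["input_ids"]
```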

def inprocess(model_name, use_existing_configuration=False, **kwargs):
"""
-Loads a 'transformers' model in-process and returns an LMQLModel object
+Loads a 'transformers' model in-process and returns an LMQLModelDescriptor object
to use this in-process model in LMQL.
This is useful when you don't want to spawn a separate 'lmql serve-model' process.
@@ -67,23 +118,29 @@ def inprocess(model_name, use_existing_configuration=False, **kwargs):
else:
cmdline_args += f"--{k} {v} "

-if cmdline_args in LMQLModel.inprocess_instances.keys():
-    warnings.warn("info: reusing existing in-process model.")
-    model = LMQLModel.inprocess_instances[cmdline_args]
-    return LMQLModel(model_name, model=model)
+if cmdline_args in LMQLModelDescriptor.inprocess_instances.keys():
+    model = LMQLModelDescriptor.inprocess_instances[cmdline_args]
+    return LMQLModelDescriptor(model_name, model=model)

if use_existing_configuration:
# find existing match for model_name only
for cmdargs, p in LMQLModel.inprocess_instances.items():
if cmdargs.split(" ")[0] == model_name:
return LMQLModel(model_name, model=p)

kwargs["inprocess"] = True
model = lmtp_model(model_name, **kwargs)
-LMQLModel.inprocess_instances[cmdline_args] = model
-return LMQLModel(model_name, model=model)
+LMQLModelDescriptor.inprocess_instances[cmdline_args] = model
+return LMQLModelDescriptor(model_name, model=model)

def model(model_identifier, **kwargs):
"""
Constructs an LMQL model descriptor object to be used in
a `from` clause or as `model=<MODEL>` argument to @lmql.query(...).
Examples:
lmql.model("openai/gpt-3.5-turbo-instruct") # OpenAI API model
lmql.model("random", seed=123) # randomly sampling model
lmql.model("llama.cpp:<YOUR_WEIGHTS>.bin") # llama.cpp model
lmql.model("local:gpt2") # load a `transformers` model in process
lmql.model("local:gpt2", cuda=True, load_in_4bit=True) # load a `transformers` model in process with additional arguments
"""
# handle inprocess models
is_inprocess = kwargs.pop("inprocess", False) or model_identifier.startswith("local:")
if is_inprocess and model_identifier.startswith("local:"):
@@ -92,4 +149,4 @@ def model(model_identifier, **kwargs):
if is_inprocess:
return inprocess(model_identifier, **kwargs)
else:
-return LMQLModel(model_identifier, **kwargs)
+return LMQLModelDescriptor(model_identifier, **kwargs)
20 changes: 20 additions & 0 deletions src/lmql/models/model_info.py
@@ -0,0 +1,20 @@
"""
This file contains a list of hard-coded model information
for LMQL to enable ready-to-use configuration for models
that have been tested and verified to work with LMQL.
"""
from dataclasses import dataclass

@dataclass
class ModelInfo:
is_chat_model: bool = False

def model_info(model_identifier):
if model_identifier == "openai/gpt-3.5-turbo-instruct" or model_identifier == "gpt-3.5-turbo-instruct":
return ModelInfo(is_chat_model=False)
elif model_identifier == "openai/gpt-4":
return ModelInfo(is_chat_model=True)
elif "gpt-3.5-turbo" in model_identifier:
return ModelInfo(is_chat_model=True)
else:
return ModelInfo(is_chat_model=False)
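
The key effect for this PR: the new instruct model is classified as a completion model even though its identifier contains "gpt-3.5-turbo", because the exact-match branch is checked first. For example:

```python
from lmql.models.model_info import model_info

assert not model_info("openai/gpt-3.5-turbo-instruct").is_chat_model  # exact match, first branch
assert model_info("openai/gpt-3.5-turbo").is_chat_model               # substring match
assert model_info("openai/gpt-4").is_chat_model
assert not model_info("local:gpt2").is_chat_model                     # default
```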
3 changes: 2 additions & 1 deletion src/lmql/ops/ops.py
@@ -9,6 +9,7 @@
from lmql.ops.inline_call import InlineCallOp
from lmql.ops.booleans import *
from lmql.ops.regex import Regex
from lmql.models.model_info import model_info

lmql_operation_registry = {}

@@ -193,7 +194,7 @@ def follow(self, v, **kwargs):
("*", False)
)

if "turbo" in context.runtime.model_identifier or "gpt-4" in context.runtime.model_identifier:
if model_info(context.runtime.model_identifier).is_chat_model:
if not all([c in "0123456789" for c in v]):
return fmap(
("*", False)