Skip to content

Commit

Permalink
Add feature to use deployments in replicate (#253)
Browse files Browse the repository at this point in the history
  • Loading branch information
fjfricke authored Nov 15, 2023
1 parent ef703b8 commit 39312cb
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 8 deletions.
2 changes: 1 addition & 1 deletion docs/docs/models/replicate.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ To run a [🤗 Transformers](./hf.html) model on Replicate, you need to:

1. Export the environment variable `REPLICATE_API_TOKEN` with the credential to use to authenticate the request.

2. Set the `transport=` argument to your model to `replicate:ORG/MODEL`, matching the name with which the model was uploaded.
2. Set the `endpoint=` argument of your model to `replicate:ORG/MODEL`, matching the name under which the model was uploaded. To use a model from one of your organization's deployments, set the `endpoint=` argument to `replicate:deployment/ORG/MODEL` instead.

3. Set the `tokenizer=` argument to your model to a huggingface transformers name from which correct configuration for the tokenizer in use can be downloaded.

Expand Down
28 changes: 21 additions & 7 deletions src/lmql/models/lmtp/lmtp_replicate_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@ def __init__(self, model_identifier, session, endpoint, **kwargs):
else: # FIXME: Allow API key to be passed in kwargs?
raise Exception('Please define REPLICATE_API_TOKEN as an environment variable to use Replicate models')

self.model_validated = False
self.use_deployment_endpoint = False

endpoint = endpoint.removeprefix('replicate:')
if len(endpoint) == 0:
endpoint = model_identifier
Expand All @@ -31,13 +34,19 @@ def __init__(self, model_identifier, session, endpoint, **kwargs):
self.model_identifier = endpoint
self.model_version = None
elif len(endpoint_pieces) == 3:
# passed a name/version pair
self.model_identifier = '/'.join(endpoint_pieces[:2])
self.model_version = endpoint_pieces[-1]
# check if it is a deployment endpoint
if endpoint_pieces[0] == 'deployment':
self.model_identifier = '/'.join(endpoint_pieces[1:3])
self.model_version = None
self.use_deployment_endpoint = True
self.model_validated = True
else:
# passed a name/version pair
self.model_identifier = '/'.join(endpoint_pieces[:2])
self.model_version = endpoint_pieces[-1]
else:
raise Exception('Unknown endpoint descriptor for replicate; should be owner/model or owner/model/version')
raise Exception('Unknown endpoint descriptor for replicate; should be owner/model, owner/model/version or deployment/owner/model')

self.model_validated = False
self.session = session
self.stream_id = 0
self.handler = None
Expand Down Expand Up @@ -81,8 +90,13 @@ async def submit_batch(self, batch):
if self.model_version is None or not self.model_validated:
await self.check_model()
# FIXME: Maybe store id to use for later cancel calls?
body = {"input": {"ops_batch_json": json.dumps(batch)}, "stream": True, "version": self.model_version}
async with self.session.post('https://api.replicate.com/v1/predictions',
if self.use_deployment_endpoint:
body = {"input": {"ops_batch_json": json.dumps(batch)}, "stream": True}
endpoint = f'https://api.replicate.com/v1/deployments/{self.model_identifier}/predictions'
else:
body = {"input": {"ops_batch_json": json.dumps(batch)}, "stream": True, "version": self.model_version}
endpoint = 'https://api.replicate.com/v1/predictions'
async with self.session.post(endpoint,
headers={
'Authorization': f'Token {self.api_key}',
'Content-Type': 'application/json'
Expand Down

0 comments on commit 39312cb

Please sign in to comment.