diff --git a/docs/docs/models/replicate.md b/docs/docs/models/replicate.md index 26b74874..e3ea9de8 100644 --- a/docs/docs/models/replicate.md +++ b/docs/docs/models/replicate.md @@ -10,7 +10,7 @@ To run a [🤗 Transformers](./hf.html) model on Replicate, you need to: 1. Export the environment variable `REPLICATE_API_TOKEN` with the credential to use to authenticate the request. -2. Set the `transport=` argument to your model to `replicate:ORG/MODEL`, matching the name with which the model was uploaded. +2. Set the `endpoint=` argument to your model to `replicate:ORG/MODEL`, matching the name with which the model was uploaded. If you want to use models from your organization's deployments, set the `endpoint=` argument to your deployment to `replicate:deployment/ORG/MODEL`. 3. Set the `tokenizer=` argument to your model to a huggingface transformers name from which correct configuration for the tokenizer in use can be downloaded. diff --git a/src/lmql/models/lmtp/lmtp_replicate_client.py b/src/lmql/models/lmtp/lmtp_replicate_client.py index d3282747..f18827b2 100644 --- a/src/lmql/models/lmtp/lmtp_replicate_client.py +++ b/src/lmql/models/lmtp/lmtp_replicate_client.py @@ -21,6 +21,9 @@ def __init__(self, model_identifier, session, endpoint, **kwargs): else: # FIXME: Allow API key to be passed in kwargs? raise Exception('Please define REPLICATE_API_TOKEN as an environment variable to use Replicate models') + self.model_validated = False + self.use_deployment_endpoint = False + endpoint = endpoint.removeprefix('replicate:') if len(endpoint) == 0: endpoint = model_identifier @@ -31,13 +34,19 @@ def __init__(self, model_identifier, session, endpoint, **kwargs): self.model_identifier = endpoint self.model_version = None elif len(endpoint_pieces) == 3: - # passed a name/version pair - self.model_identifier = '/'.join(endpoint_pieces[:2]) - self.model_version = endpoint_pieces[-1] + # check if it is a deployment endpoint + if endpoint_pieces[0] == 'deployment': + self.model_identifier = '/'.join(endpoint_pieces[1:3]) + self.model_version = None + self.use_deployment_endpoint = True + self.model_validated = True + else: + # passed a name/version pair + self.model_identifier = '/'.join(endpoint_pieces[:2]) + self.model_version = endpoint_pieces[-1] else: - raise Exception('Unknown endpoint descriptor for replicate; should be owner/model or owner/model/version') + raise Exception('Unknown endpoint descriptor for replicate; should be owner/model, owner/model/version' or 'deployment/owner/model') - self.model_validated = False self.session = session self.stream_id = 0 self.handler = None @@ -81,8 +90,13 @@ async def submit_batch(self, batch): if self.model_version is None or not self.model_validated: await self.check_model() # FIXME: Maybe store id to use for later cancel calls? - body = {"input": {"ops_batch_json": json.dumps(batch)}, "stream": True, "version": self.model_version} - async with self.session.post('https://api.replicate.com/v1/predictions', + if self.use_deployment_endpoint: + body = {"input": {"ops_batch_json": json.dumps(batch)}, "stream": True} + endpoint = f'https://api.replicate.com/v1/deployments/{self.model_identifier}/predictions' + else: + body = {"input": {"ops_batch_json": json.dumps(batch)}, "stream": True, "version": self.model_version} + endpoint = 'https://api.replicate.com/v1/predictions' + async with self.session.post(endpoint, headers={ 'Authorization': f'Token {self.api_key}', 'Content-Type': 'application/json'