diff --git a/data_gemma/openai_api.py b/data_gemma/openai_api.py index a81a12d..8ded65e 100644 --- a/data_gemma/openai_api.py +++ b/data_gemma/openai_api.py @@ -30,11 +30,13 @@ class OpenAI(base.LLM): def __init__( self, model: str, + base_url: str, api_key: str, verbose: bool = True, session: requests.Session | None = None, ): self.key = api_key + self.base_url = base_url if not session: session = requests.Session() self.session: requests.Session = session @@ -86,7 +88,7 @@ def _call_api(self, req_data: str) -> Any: 'Authorization': f'Bearer {self.key}', } r = self.session.post( - 'https://api.openai.com/v1/chat/completions', + self.base_url, data=req_data, headers=header, )