From 87a808dc8ad695d2524263456b71b4cc25e7fed6 Mon Sep 17 00:00:00 2001 From: Adam Kells Date: Thu, 20 Jun 2024 08:58:18 +0100 Subject: [PATCH 1/2] base integration and openai --- .../ai_integrations/baseintegration.py | 24 +++++++++++++ healthchain/ai_integrations/openai.py | 35 +++++++++++++++++++ 2 files changed, 59 insertions(+) create mode 100644 healthchain/ai_integrations/baseintegration.py create mode 100644 healthchain/ai_integrations/openai.py diff --git a/healthchain/ai_integrations/baseintegration.py b/healthchain/ai_integrations/baseintegration.py new file mode 100644 index 0000000..1bcf1ce --- /dev/null +++ b/healthchain/ai_integrations/baseintegration.py @@ -0,0 +1,24 @@ +from typing import Optional + + +class AIIntegrationBase: + def __init__(self, model_name: str, api_key: Optional[str]): + self.model_name = model_name + self.model = self.load_model(model_name) + + def load_model(self, model_name: str, api_key: Optional[str]): + raise NotImplementedError("Subclasses should implement this method.") + + def preprocess(self, data): + raise NotImplementedError("Subclasses should implement this method.") + + def predict(self, data): + raise NotImplementedError("Subclasses should implement this method.") + + def postprocess(self, prediction): + raise NotImplementedError("Subclasses should implement this method.") + + def run(self, data): + preprocessed_data = self.preprocess(data) + prediction = self.predict(preprocessed_data) + return self.postprocess(prediction) diff --git a/healthchain/ai_integrations/openai.py b/healthchain/ai_integrations/openai.py new file mode 100644 index 0000000..1b61fb7 --- /dev/null +++ b/healthchain/ai_integrations/openai.py @@ -0,0 +1,35 @@ +from healthchain.ai_integrations.baseintegration import AIIntegrationBase +import openai + + +class OpenAIIntegration(AIIntegrationBase): + def __init__(self, model_name: str, api_key: str, **params): + self.api_key = api_key + self.params = params + super().__init__(model_name) + + def load_model(self, model_name: str): + openai.api_key = self.api_key + return model_name + + def preprocess(self, data): + return data + + def predict(self, data): + responses = [] + if isinstance(data, list): + for item in data: + response = openai.Completion.create( + model=self.model_name, prompt=item, **self.params + ) + responses.append(response.choices[0].text.strip()) + else: + response = openai.Completion.create( + model=self.model_name, prompt=data, max_tokens=50 + ) + responses.append(response.choices[0].text.strip()) + return responses + + def postprocess(self, prediction): + # For simplicity, return the raw prediction + return prediction From 0f48aa350151e9a1818e839a9d3bcbdfcf41c47b Mon Sep 17 00:00:00 2001 From: Adam Kells Date: Thu, 20 Jun 2024 08:58:26 +0100 Subject: [PATCH 2/2] base integration and openai --- healthchain/ai_integrations/openai.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/healthchain/ai_integrations/openai.py b/healthchain/ai_integrations/openai.py index 1b61fb7..0339a0d 100644 --- a/healthchain/ai_integrations/openai.py +++ b/healthchain/ai_integrations/openai.py @@ -25,7 +25,7 @@ def predict(self, data): responses.append(response.choices[0].text.strip()) else: response = openai.Completion.create( - model=self.model_name, prompt=data, max_tokens=50 + model=self.model_name, prompt=data, **self.params ) responses.append(response.choices[0].text.strip()) return responses