diff --git a/healthchain/ai_integrations/baseintegration.py b/healthchain/ai_integrations/baseintegration.py new file mode 100644 index 0000000..1bcf1ce --- /dev/null +++ b/healthchain/ai_integrations/baseintegration.py @@ -0,0 +1,24 @@ +from typing import Optional + + +class AIIntegrationBase: + def __init__(self, model_name: str, api_key: Optional[str]): + self.model_name = model_name + self.model = self.load_model(model_name) + + def load_model(self, model_name: str, api_key: Optional[str]): + raise NotImplementedError("Subclasses should implement this method.") + + def preprocess(self, data): + raise NotImplementedError("Subclasses should implement this method.") + + def predict(self, data): + raise NotImplementedError("Subclasses should implement this method.") + + def postprocess(self, prediction): + raise NotImplementedError("Subclasses should implement this method.") + + def run(self, data): + preprocessed_data = self.preprocess(data) + prediction = self.predict(preprocessed_data) + return self.postprocess(prediction) diff --git a/healthchain/ai_integrations/openai.py b/healthchain/ai_integrations/openai.py new file mode 100644 index 0000000..0339a0d --- /dev/null +++ b/healthchain/ai_integrations/openai.py @@ -0,0 +1,35 @@ +from healthchain.ai_integrations.baseintegration import AIIntegrationBase +import openai + + +class OpenAIIntegration(AIIntegrationBase): + def __init__(self, model_name: str, api_key: str, **params): + self.api_key = api_key + self.params = params + super().__init__(model_name) + + def load_model(self, model_name: str): + openai.api_key = self.api_key + return model_name + + def preprocess(self, data): + return data + + def predict(self, data): + responses = [] + if isinstance(data, list): + for item in data: + response = openai.Completion.create( + model=self.model_name, prompt=item, **self.params + ) + responses.append(response.choices[0].text.strip()) + else: + response = openai.Completion.create( + model=self.model_name, prompt=data, **self.params + ) + responses.append(response.choices[0].text.strip()) + return responses + + def postprocess(self, prediction): + # For simplicity, return the raw prediction + return prediction