Add XAI Support #86

Open: wants to merge 1 commit into base branch main.
src/beyondllm/llms/xAi.py (63 additions, 0 deletions)
@@ -0,0 +1,63 @@
from beyondllm.llms.base import BaseLLMModel, ModelConfig
from typing import Any, Dict
from dataclasses import dataclass, field
import os

@dataclass
class XAiModel:
    """
    Class representing a Large Language Model (LLM) from XAi.

    Example:
    ```
    >>> llm = XAiModel(api_key="<your_api_key>", model_kwargs={"temperature": 0.5})
    ```
    or
    ```
    >>> import os
    >>> os.environ['XAi_API_KEY'] = "***********"  # replace with your key
    >>> llm = XAiModel()
    ```
    """
    api_key: str = ""  # empty default, so the env-var fallback in __post_init__ can trigger
    model_name: str = "command-r-plus-08-2024"
    model_kwargs: dict = field(default_factory=lambda: {
        "temperature": 0.5,
        "top_p": 1,
        "max_tokens": 2048,
    })

    def __post_init__(self):
        if not self.api_key:
            self.api_key = os.getenv('XAi_API_KEY')
        if not self.api_key:
            raise ValueError("XAi_API_KEY is not provided and not found in environment variables.")
        self.load_llm()

    def load_llm(self):
        try:
            import XAi
        except ImportError:
            print("The XAi module is not installed. Please install it with 'pip install XAi'.")

        try:
            self.client = XAi.ClientV2(api_key=self.api_key)
        except Exception as e:
            raise Exception(f"Failed to initialize XAi client: {str(e)}")
Comment on lines +38 to +46
Suggestion: Ensure that the load_llm method checks if the XAi module is successfully imported before attempting to use it, to prevent potential runtime errors. [possible bug]

Suggested change

Before:
```python
try:
    import XAi
except ImportError:
    print("The XAi module is not installed. Please install it with 'pip install XAi'.")
try:
    self.client = XAi.ClientV2(api_key=self.api_key)
except Exception as e:
    raise Exception(f"Failed to initialize XAi client: {str(e)}")
```

After:
```python
try:
    import XAi
except ImportError:
    print("The XAi module is not installed. Please install it with 'pip install XAi'.")
    return
try:
    self.client = XAi.ClientV2(api_key=self.api_key)
except Exception as e:
    raise Exception(f"Failed to initialize XAi client: {str(e)}")
```
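The failure mode behind this suggestion is that `print` does not stop execution, so `XAi.ClientV2` is still reached even when the import failed. The fail-fast guard can be sketched in isolation; `optional_import` is a hypothetical helper, and `json` merely stands in for the real SDK:

```python
import importlib

def optional_import(module_name: str):
    """Return the imported module, or None when it is not installed."""
    try:
        return importlib.import_module(module_name)
    except ImportError:
        return None

# Bail out of setup when the dependency is missing, instead of hitting a
# NameError later, when the module name is first used.
sdk = optional_import("json")  # stdlib stand-in for the real SDK
if sdk is None:
    raise RuntimeError("Dependency missing; install it before constructing the client.")
```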


    def predict(self, prompt: Any) -> str:
        try:
            response = self.client.chat(
                model=self.model_name,
                messages=[{"role": "user", "content": prompt}]
            )
            return response.message.content[0].text
Suggestion: Handle the case where response.message.content might be empty or not structured as expected to prevent potential index errors in the predict method. [possible issue]

Suggested change

Before:
```python
return response.message.content[0].text
```

After:
```python
if response.message.content:
    return response.message.content[0].text
else:
    raise Exception("Received an empty response from the model.")
```

        except Exception as e:
            raise Exception(f"Failed to generate prediction: {str(e)}")
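The empty-response concern above can be illustrated standalone. A small sketch of the defensive access pattern, assuming a content list of `{"text": ...}` dicts rather than the SDK's real response objects:

```python
def first_message_text(content) -> str:
    """Safely pull the first text item from a response content list."""
    if content:  # guards against both None and an empty list
        return content[0]["text"]
    raise ValueError("Received an empty response from the model.")

# A well-formed response yields its text; an empty one raises a clear
# error instead of an opaque IndexError.
print(first_message_text([{"text": "hello"}]))  # prints "hello"
```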

    @staticmethod
    def load_from_kwargs(self, kwargs: Dict):
        model_config = ModelConfig(**kwargs)
        self.config = model_config
        self.load_llm()

Comment on lines +58 to +63
Suggestion: Modify the load_from_kwargs method to avoid using self as a parameter in a static method, as it is not appropriate and can cause confusion. [best practice]

Suggested change

Before:
```python
@staticmethod
def load_from_kwargs(self, kwargs: Dict):
    model_config = ModelConfig(**kwargs)
    self.config = model_config
    self.load_llm()
```

After:
```python
@staticmethod
def load_from_kwargs(kwargs: Dict):
    model_config = ModelConfig(**kwargs)
    instance = XAiModel()
    instance.config = model_config
    instance.load_llm()
```
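As a design note, a `@classmethod` factory is a common alternative to a `@staticmethod` version here: it drops the stray `self` without hardcoding the class name, and it returns the new instance to the caller. A sketch with stand-in classes (`DemoModel` and `DemoConfig` are illustrative, not the PR's actual types):

```python
from dataclasses import dataclass

@dataclass
class DemoConfig:
    temperature: float = 0.5

class DemoModel:
    # A classmethod factory receives the class itself, so it needs neither
    # a bogus `self` parameter nor a hardcoded class name in the body.
    @classmethod
    def load_from_kwargs(cls, kwargs: dict) -> "DemoModel":
        instance = cls()
        instance.config = DemoConfig(**kwargs)
        return instance

model = DemoModel.load_from_kwargs({"temperature": 0.9})
```

Either way, note that a factory which never returns the instance it builds leaves the caller with no handle on the result.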
