-
Notifications
You must be signed in to change notification settings - Fork 42
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add XAI Support #86
base: main
Are you sure you want to change the base?
Add XAI Support #86
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
@@ -0,0 +1,63 @@ | ||||||||||||||||||||||||||
from beyondllm.llms.base import BaseLLMModel, ModelConfig | ||||||||||||||||||||||||||
from typing import Any, Dict | ||||||||||||||||||||||||||
from dataclasses import dataclass, field | ||||||||||||||||||||||||||
import os | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
@dataclass | ||||||||||||||||||||||||||
class XAiModel: | ||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||
Class representing a Language Model (LLM) model using XAi. | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
Example: | ||||||||||||||||||||||||||
``` | ||||||||||||||||||||||||||
>>> llm = XAiModel(api_key="<your_api_key>", model_kwargs={"temperature": 0.5}) | ||||||||||||||||||||||||||
``` | ||||||||||||||||||||||||||
or | ||||||||||||||||||||||||||
``` | ||||||||||||||||||||||||||
>>> import os | ||||||||||||||||||||||||||
>>> os.environ['XAi_API_KEY'] = "***********" #replace with your key | ||||||||||||||||||||||||||
>>> llm = XAiModel() | ||||||||||||||||||||||||||
``` | ||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||
api_key: str =" " | ||||||||||||||||||||||||||
model_name: str = "command-r-plus-08-2024" | ||||||||||||||||||||||||||
model_kwargs: dict = field(default_factory=lambda: { | ||||||||||||||||||||||||||
"temperature": 0.5, | ||||||||||||||||||||||||||
"top_p": 1, | ||||||||||||||||||||||||||
"max_tokens": 2048, | ||||||||||||||||||||||||||
}) | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
def __post_init__(self): | ||||||||||||||||||||||||||
if not self.api_key: | ||||||||||||||||||||||||||
self.api_key = os.getenv('XAi_API_KEY') | ||||||||||||||||||||||||||
if not self.api_key: | ||||||||||||||||||||||||||
raise ValueError("XAi_API_KEY is not provided and not found in environment variables.") | ||||||||||||||||||||||||||
self.load_llm() | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
def load_llm(self): | ||||||||||||||||||||||||||
try: | ||||||||||||||||||||||||||
import XAi | ||||||||||||||||||||||||||
except ImportError: | ||||||||||||||||||||||||||
print("The XAi module is not installed. Please install it with 'pip install XAi'.") | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
try: | ||||||||||||||||||||||||||
self.client = XAi.ClientV2(api_key=self.api_key) | ||||||||||||||||||||||||||
except Exception as e: | ||||||||||||||||||||||||||
raise Exception(f"Failed to initialize XAi client: {str(e)}") | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
def predict(self, prompt: Any) -> str: | ||||||||||||||||||||||||||
try: | ||||||||||||||||||||||||||
response = self.client.chat( | ||||||||||||||||||||||||||
model=self.model_name, | ||||||||||||||||||||||||||
messages=[{"role": "user", "content": prompt}] | ||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||
return response.message.content[0].text | ||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Suggestion: Handle the case where
Suggested change
|
||||||||||||||||||||||||||
except Exception as e: | ||||||||||||||||||||||||||
raise Exception(f"Failed to generate prediction: {str(e)}") | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
@staticmethod | ||||||||||||||||||||||||||
def load_from_kwargs(self, kwargs: Dict): | ||||||||||||||||||||||||||
model_config = ModelConfig(**kwargs) | ||||||||||||||||||||||||||
self.config = model_config | ||||||||||||||||||||||||||
self.load_llm() | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
Comment on lines
+58
to
+63
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Suggestion: Modify the
Suggested change
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Suggestion: Ensure that the
load_llm
method checks if theXAi
module is successfully imported before attempting to use it, to prevent potential runtime errors. [possible bug]