Merge pull request #229 from leeeizhang/lei/seperate-models
[MRG] seperate LLM models
Showing 8 changed files with 620 additions and 595 deletions.
This file was deleted.
@@ -0,0 +1,70 @@
from .anthropic import *
from .deepseek import *
from .mistral import *
from .ollama import *
from .openai import *

from mle.utils import get_config


MODEL_OLLAMA = 'Ollama'
MODEL_OPENAI = 'OpenAI'
MODEL_CLAUDE = 'Claude'
MODEL_MISTRAL = 'MistralAI'
MODEL_DEEPSEEK = 'DeepSeek'


class ObservableModel:
    """
    A class that wraps a model to make it trackable by the metric platform (e.g., Langfuse).
    """

    try:
        from mle.utils import get_langfuse_observer
        _observe = get_langfuse_observer()
    except Exception as e:
        # If importing fails, set _observe to a lambda function that does nothing.
        _observe = lambda fn: fn

    def __init__(self, model: Model):
        """
        Initialize the ObservableModel.
        Args:
            model: The model to be wrapped and made observable.
        """
        self.model = model

    @_observe
    def query(self, *args, **kwargs):
        return self.model.query(*args, **kwargs)

    @_observe
    def stream(self, *args, **kwargs):
        return self.model.stream(*args, **kwargs)


def load_model(project_dir: str, model_name: str = None, observable=True):
    """
    load_model: load the model based on the configuration.
    Args:
        project_dir (str): The project directory.
        model_name (str): The model name.
        observable (bool): Whether the model should be tracked.
    """
    config = get_config(project_dir)
    model = None

    if config['platform'] == MODEL_OLLAMA:
        model = OllamaModel(model=model_name)
    if config['platform'] == MODEL_OPENAI:
        model = OpenAIModel(api_key=config['api_key'], model=model_name)
    if config['platform'] == MODEL_CLAUDE:
        model = ClaudeModel(api_key=config['api_key'], model=model_name)
    if config['platform'] == MODEL_MISTRAL:
        model = MistralModel(api_key=config['api_key'], model=model_name)
    if config['platform'] == MODEL_DEEPSEEK:
        model = DeepSeekModel(api_key=config['api_key'], model=model_name)

    if observable:
        return ObservableModel(model)
    return model
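
The factory above reads the project configuration, instantiates the matching backend, and by default wraps it in ObservableModel so query/stream calls can be traced. A minimal usage sketch, assuming this module is the package's __init__ and the project directory's config provides 'platform' and 'api_key'; the directory path and model name below are hypothetical:

from mle.model import load_model

# Hypothetical project directory; get_config() is expected to find a config
# there with 'platform' (e.g. 'OpenAI') and 'api_key' entries.
model = load_model("./my_project", model_name="gpt-4o", observable=True)

# ObservableModel forwards query()/stream() to the wrapped backend, so the
# call site looks the same whether or not tracking is enabled.
print(model.query([{"role": "user", "content": "Summarize this dataset."}]))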
@@ -0,0 +1,144 @@
import importlib.util

from mle.function import SEARCH_FUNCTIONS, get_function, process_function_name
from mle.model.common import Model


class ClaudeModel(Model):
    def __init__(self, api_key, model, temperature=0.7):
        """
        Initialize the Claude model.
        Args:
            api_key (str): The Anthropic API key.
            model (str): The model with version.
            temperature (float): The temperature value.
        """
        super().__init__()

        dependency = "anthropic"
        spec = importlib.util.find_spec(dependency)
        if spec is not None:
            self.anthropic = importlib.import_module(dependency).Anthropic
        else:
            raise ImportError(
                "It seems you didn't install anthropic. In order to enable the Anthropic client related features, "
                "please make sure the anthropic Python package has been installed. "
                "For more information, please refer to: https://docs.anthropic.com/en/api/client-sdks"
            )

        self.model = model if model else 'claude-3-5-sonnet-20240620'
        self.model_type = 'Claude'
        self.temperature = temperature
        self.client = self.anthropic(api_key=api_key)
        self.func_call_history = []

    @staticmethod
    def _add_tool_result_into_chat_history(chat_history, func, result):
        """
        Add the result of tool calls into messages.
        """
        return chat_history.extend([
            {
                "role": "assistant",
                "content": [
                    {
                        "type": "tool_use",
                        "id": func.id,
                        "name": func.name,
                        "input": func.input,
                    },
                ]
            },
            {
                "role": "user",
                "content": [
                    {
                        "type": "tool_result",
                        "tool_use_id": func.id,
                        "content": result,
                    },
                ]
            },
        ])

    def query(self, chat_history, **kwargs):
        """
        Query the LLM model.
        Args:
            chat_history: The context (chat history).
        """
        # Claude does not accept a system role in chat_history, so collect it into a separate system prompt.
        # https://docs.anthropic.com/en/docs/build-with-claude/prompt-engineering/system-prompts
        system_prompt = ""
        for idx, msg in enumerate(chat_history):
            if msg["role"] == "system":
                system_prompt += msg["content"]

        # Claude does not support a manual `response_format`, so we append it to the system prompt.
        if "response_format" in kwargs.keys():
            system_prompt += (
                f"\nOutputs only valid {kwargs['response_format']['type']} without any explanatory words"
            )

        # map the OpenAI function_schema to the Claude tool_schema
        tools = kwargs.get("functions", [])
        for tool in tools:
            if "parameters" in tool.keys():
                tool["input_schema"] = tool["parameters"]
                del tool["parameters"]

        completion = self.client.messages.create(
            max_tokens=4096,
            model=self.model,
            system=system_prompt,
            messages=[msg for msg in chat_history if msg["role"] != "system"],
            temperature=self.temperature,
            stream=False,
            tools=tools,
        )
        if completion.stop_reason == "tool_use":
            for func in completion.content:
                if func.type != "tool_use":
                    continue
                function_name = process_function_name(func.name)
                arguments = func.input
                print("[MLE FUNC CALL]: ", function_name)
                self.func_call_history.append({"name": function_name, "arguments": arguments})
                # avoid multiple search function calls
                search_attempts = [item for item in self.func_call_history if item['name'] in SEARCH_FUNCTIONS]
                if len(search_attempts) > 3:
                    kwargs['functions'] = []
                result = get_function(function_name)(**arguments)
                self._add_tool_result_into_chat_history(chat_history, func, result)
            return self.query(chat_history, **kwargs)
        else:
            return completion.content[0].text

    def stream(self, chat_history, **kwargs):
        """
        Stream the output from the LLM model.
        Args:
            chat_history: The context (chat history).
        """
        # Claude does not accept a system role in chat_history, so collect it into a separate system prompt.
        # https://docs.anthropic.com/en/docs/build-with-claude/prompt-engineering/system-prompts
        system_prompt = ""
        for idx, msg in enumerate(chat_history):
            if msg["role"] == "system":
                system_prompt += msg["content"]
        chat_history = [msg for msg in chat_history if msg["role"] != "system"]

        # Claude does not support a manual `response_format`, so we append it to the system prompt.
        if "response_format" in kwargs.keys():
            system_prompt += (
                f"\nOutputs only valid {kwargs['response_format']['type']} without any explanatory words"
            )

        with self.client.messages.stream(
            max_tokens=4096,
            model=self.model,
            system=system_prompt,
            messages=chat_history,
        ) as stream:
            for chunk in stream.text_stream:
                yield chunk
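
The function-calling support in ClaudeModel.query rests on a small schema rewrite: an OpenAI-style function definition becomes a Claude tool by moving its "parameters" key to "input_schema". A short illustration of that rename; the web_search function is a made-up example, not part of this PR:

# Hypothetical OpenAI-style function schema.
functions = [{
    "name": "web_search",
    "description": "Search the web for a query.",
    "parameters": {
        "type": "object",
        "properties": {"query": {"type": "string"}},
        "required": ["query"],
    },
}]

# The same in-place rename that ClaudeModel.query() performs before calling
# client.messages.create(..., tools=functions).
for tool in functions:
    if "parameters" in tool:
        tool["input_schema"] = tool.pop("parameters")

assert "input_schema" in functions[0] and "parameters" not in functions[0]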
@@ -0,0 +1,18 @@
from abc import ABC, abstractmethod


class Model(ABC):

    def __init__(self):
        """
        Initialize the model.
        """
        self.model_type = None

    @abstractmethod
    def query(self, chat_history, **kwargs):
        pass

    @abstractmethod
    def stream(self, chat_history, **kwargs):
        pass
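
Every backend in this PR subclasses this Model interface and fills in query and stream. A minimal sketch of a conforming subclass, using a made-up EchoModel that is not part of the PR:

from mle.model.common import Model


class EchoModel(Model):
    """Toy backend that echoes the last message; useful only as a shape example."""

    def __init__(self):
        super().__init__()
        self.model_type = 'Echo'

    def query(self, chat_history, **kwargs):
        # Return the content of the most recent message as the "completion".
        return chat_history[-1]["content"]

    def stream(self, chat_history, **kwargs):
        # Yield the reply in small chunks to mimic token streaming.
        reply = self.query(chat_history, **kwargs)
        for i in range(0, len(reply), 8):
            yield reply[i:i + 8]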
@@ -0,0 +1,116 @@
import importlib.util
import json

from mle.function import SEARCH_FUNCTIONS, get_function, process_function_name
from mle.model.common import Model


class DeepSeekModel(Model):
    def __init__(self, api_key, model, temperature=0.7):
        """
        Initialize the DeepSeek model.
        Args:
            api_key (str): The DeepSeek API key.
            model (str): The model with version.
            temperature (float): The temperature value.
        """
        super().__init__()

        dependency = "openai"
        spec = importlib.util.find_spec(dependency)
        if spec is not None:
            self.openai = importlib.import_module(dependency).OpenAI
        else:
            raise ImportError(
                "It seems you didn't install openai. In order to enable the OpenAI client related features, "
                "please make sure the openai Python package has been installed. "
                "For more information, please refer to: https://openai.com/product"
            )
        self.model = model if model else "deepseek-coder"
        self.model_type = 'DeepSeek'
        self.temperature = temperature
        self.client = self.openai(
            api_key=api_key, base_url="https://api.deepseek.com/beta"
        )
        self.func_call_history = []

    def _convert_functions_to_tools(self, functions):
        """
        Convert OpenAI-style functions to DeepSeek-style tools.
        """
        tools = []
        for func in functions:
            tool = {
                "type": "function",
                "function": {
                    "name": func["name"],
                    "description": func.get("description", ""),
                    "parameters": func["parameters"],
                },
            }
            tools.append(tool)
        return tools

    def query(self, chat_history, **kwargs):
        """
        Query the LLM model.
        Args:
            chat_history: The context (chat history).
        """
        functions = kwargs.get("functions", None)
        tools = self._convert_functions_to_tools(functions) if functions else None
        parameters = kwargs
        completion = self.client.chat.completions.create(
            model=self.model,
            messages=chat_history,
            temperature=self.temperature,
            stream=False,
            tools=tools,
            **parameters,
        )

        resp = completion.choices[0].message
        if resp.tool_calls:
            for tool_call in resp.tool_calls:
                chat_history.append({"role": "assistant", "content": '', "tool_calls": [tool_call], "prefix": False})
                function_name = process_function_name(tool_call.function.name)
                arguments = json.loads(tool_call.function.arguments)
                print("[MLE FUNC CALL]: ", function_name)
                self.func_call_history.append({"name": function_name, "arguments": arguments})
                # avoid multiple search function calls
                search_attempts = [item for item in self.func_call_history if item['name'] in SEARCH_FUNCTIONS]
                if len(search_attempts) > 3:
                    parameters['tool_choice'] = "none"
                result = get_function(function_name)(**arguments)
                chat_history.append({"role": "tool", "content": result, "name": function_name, "tool_call_id": tool_call.id})
            return self.query(chat_history, **parameters)
        else:
            return resp.content

    def stream(self, chat_history, **kwargs):
        """
        Stream the output from the LLM model.
        Args:
            chat_history: The context (chat history).
        """
        arguments = ""
        function_name = ""
        for chunk in self.client.chat.completions.create(
            model=self.model,
            messages=chat_history,
            temperature=self.temperature,
            stream=True,
            **kwargs,
        ):
            if chunk.choices[0].delta.tool_calls:
                tool_call = chunk.choices[0].delta.tool_calls[0]
                if tool_call.function.name:
                    chat_history.append({"role": "assistant", "content": '', "tool_calls": [tool_call], "prefix": False})
                    function_name = process_function_name(tool_call.function.name)
                    arguments = json.loads(tool_call.function.arguments)
                    result = get_function(function_name)(**arguments)
                    chat_history.append({"role": "tool", "content": result, "name": function_name})
                    yield from self.stream(chat_history, **kwargs)
            else:
                yield chunk.choices[0].delta.content
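
_convert_functions_to_tools simply nests each OpenAI-style function under a {"type": "function", "function": ...} envelope before the request is sent. A quick sketch of the shape it produces; the run_sql function is a hypothetical example:

# Hypothetical OpenAI-style function definition.
functions = [{
    "name": "run_sql",
    "description": "Run a read-only SQL query.",
    "parameters": {
        "type": "object",
        "properties": {"sql": {"type": "string"}},
        "required": ["sql"],
    },
}]

# Equivalent to DeepSeekModel._convert_functions_to_tools(functions).
tools = [
    {"type": "function", "function": {
        "name": f["name"],
        "description": f.get("description", ""),
        "parameters": f["parameters"],
    }}
    for f in functions
]

assert tools[0]["type"] == "function"
assert tools[0]["function"]["name"] == "run_sql"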