[MRG] seperate LLM models #229

Merged 1 commit on Oct 8, 2024
595 changes: 0 additions & 595 deletions mle/model.py

This file was deleted.

70 changes: 70 additions & 0 deletions mle/model/__init__.py
@@ -0,0 +1,70 @@
from .anthropic import *
from .deepseek import *
from .mistral import *
from .ollama import *
from .openai import *
from .common import Model

from mle.utils import get_config


MODEL_OLLAMA = 'Ollama'
MODEL_OPENAI = 'OpenAI'
MODEL_CLAUDE = 'Claude'
MODEL_MISTRAL = 'MistralAI'
MODEL_DEEPSEEK = 'DeepSeek'


class ObservableModel:
"""
A class that wraps a model to make it trackable by the metric platform (e.g., Langfuse).
"""

try:
from mle.utils import get_langfuse_observer
_observe = get_langfuse_observer()
except Exception:
# If importing fails, fall back to a no-op decorator.
_observe = lambda fn: fn

def __init__(self, model: Model):
"""
Initialize the ObservableModel.
Args:
model: The model to be wrapped and made observable.
"""
self.model = model

@_observe
def query(self, *args, **kwargs):
return self.model.query(*args, **kwargs)

@_observe
def stream(self, *args, **kwargs):
return self.model.stream(*args, **kwargs)


def load_model(project_dir: str, model_name: str = None, observable: bool = True):
"""
Load the model based on the configuration.
Args:
project_dir (str): The project directory.
model_name (str): The model name.
observable (boolean): Whether the model should be tracked.
"""
config = get_config(project_dir)
model = None

if config['platform'] == MODEL_OLLAMA:
model = OllamaModel(model=model_name)
elif config['platform'] == MODEL_OPENAI:
model = OpenAIModel(api_key=config['api_key'], model=model_name)
elif config['platform'] == MODEL_CLAUDE:
model = ClaudeModel(api_key=config['api_key'], model=model_name)
elif config['platform'] == MODEL_MISTRAL:
model = MistralModel(api_key=config['api_key'], model=model_name)
elif config['platform'] == MODEL_DEEPSEEK:
model = DeepSeekModel(api_key=config['api_key'], model=model_name)

if observable:
return ObservableModel(model)
return model
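
A minimal usage sketch of the new entry point; the project path and model name below are hypothetical, and it assumes the project config sets platform to OpenAI with a valid api_key:

    # Hypothetical usage, not part of this PR.
    from mle.model import load_model

    model = load_model("./my-project", model_name="gpt-4o")  # returns an ObservableModel wrapper
    answer = model.query([{"role": "user", "content": "Summarize the dataset schema."}])
    print(answer)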
144 changes: 144 additions & 0 deletions mle/model/anthropic.py
@@ -0,0 +1,144 @@
import importlib.util

from mle.function import SEARCH_FUNCTIONS, get_function, process_function_name
from mle.model.common import Model


class ClaudeModel(Model):
def __init__(self, api_key, model, temperature=0.7):
"""
Initialize the Claude model.
Args:
api_key (str): The Anthropic API key.
model (str): The model with version.
temperature (float): The temperature value.
"""
super().__init__()

dependency = "anthropic"
spec = importlib.util.find_spec(dependency)
if spec is not None:
self.anthropic = importlib.import_module(dependency).Anthropic
else:
raise ImportError(
"It seems you didn't install anthropic. In order to enable the OpenAI client related features, "
"please make sure openai Python package has been installed. "
"More information, please refer to: https://docs.anthropic.com/en/api/client-sdks"
)

self.model = model if model else 'claude-3-5-sonnet-20240620'
self.model_type = 'Claude'
self.temperature = temperature
self.client = self.anthropic(api_key=api_key)
self.func_call_history = []

@staticmethod
def _add_tool_result_into_chat_history(chat_history, func, result):
"""
Add the result of tool calls into messages.
"""
return chat_history.extend([
{
"role": "assistant",
"content": [
{
"type": "tool_use",
"id": func.id,
"name": func.name,
"input": func.input,
},
]
},
{
"role": "user",
"content": [
{
"type": "tool_result",
"tool_use_id": func.id,
"content": result,
},
]
},
])

def query(self, chat_history, **kwargs):
"""
Query the LLM model.

Args:
chat_history: The context (chat history).
"""
# Claude has no system role in chat_history, so collect system messages into a single system prompt
# https://docs.anthropic.com/en/docs/build-with-claude/prompt-engineering/system-prompts
system_prompt = ""
for idx, msg in enumerate(chat_history):
if msg["role"] == "system":
system_prompt += msg["content"]

# Claude does not support an explicit `response_format`, so we append it to the system prompt
if "response_format" in kwargs.keys():
system_prompt += (
f"\nOutputs only valid {kwargs['response_format']['type']} without any explanatory words"
)

# map the OpenAI function schema to the Claude tool schema
tools = kwargs.get("functions", [])
for tool in tools:
if "parameters" in tool.keys():
tool["input_schema"] = tool["parameters"]
del tool["parameters"]

completion = self.client.messages.create(
max_tokens=4096,
model=self.model,
system=system_prompt,
messages=[msg for msg in chat_history if msg["role"] != "system"],
temperature=self.temperature,
stream=False,
tools=tools,
)
if completion.stop_reason == "tool_use":
for func in completion.content:
if func.type != "tool_use":
continue
function_name = process_function_name(func.name)
arguments = func.input
print("[MLE FUNC CALL]: ", function_name)
self.func_call_history.append({"name": function_name, "arguments": arguments})
# avoid excessive search function calls: disable functions after three search attempts
search_attempts = [item for item in self.func_call_history if item['name'] in SEARCH_FUNCTIONS]
if len(search_attempts) > 3:
kwargs['functions'] = []
result = get_function(function_name)(**arguments)
self._add_tool_result_into_chat_history(chat_history, func, result)
return self.query(chat_history, **kwargs)
else:
return completion.content[0].text

def stream(self, chat_history, **kwargs):
"""
Stream the output from the LLM model.
Args:
chat_history: The context (chat history).
"""
# Claude has no system role in chat_history, so collect system messages into a single system prompt
# https://docs.anthropic.com/en/docs/build-with-claude/prompt-engineering/system-prompts
system_prompt = ""
for idx, msg in enumerate(chat_history):
if msg["role"] == "system":
system_prompt += msg["content"]
chat_history = [msg for msg in chat_history if msg["role"] != "system"]

# Claude does not support an explicit `response_format`, so we append it to the system prompt
if "response_format" in kwargs.keys():
system_prompt += (
f"\nOutputs only valid {kwargs['response_format']['type']} without any explanatory words"
)

with self.client.messages.stream(
max_tokens=4096,
model=self.model,
system=system_prompt,
messages=chat_history,
temperature=self.temperature,
) as stream:
for chunk in stream.text_stream:
yield chunk
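
For reference, the in-place mapping in query turns an OpenAI-style function definition into a Claude tool by renaming parameters to input_schema; a small illustration with a made-up schema:

    # Illustrative only: a made-up OpenAI-style function definition.
    functions = [{
        "name": "web_search",
        "description": "Search the web for a query",
        "parameters": {"type": "object", "properties": {"query": {"type": "string"}}},
    }]

    # The same renaming query() applies before calling client.messages.create(tools=...).
    for tool in functions:
        if "parameters" in tool:
            tool["input_schema"] = tool.pop("parameters")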
18 changes: 18 additions & 0 deletions mle/model/common.py
@@ -0,0 +1,18 @@
from abc import ABC, abstractmethod


class Model(ABC):

def __init__(self):
"""
Initialize the model.
"""
self.model_type = None

@abstractmethod
def query(self, chat_history, **kwargs):
pass

@abstractmethod
def stream(self, chat_history, **kwargs):
pass
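
With this abstract base in place, adding another backend only requires implementing query and stream; a hedged sketch (EchoModel is illustrative, not part of this PR):

    from mle.model.common import Model

    class EchoModel(Model):
        """Illustrative backend that simply echoes the last message."""

        def __init__(self):
            super().__init__()
            self.model_type = 'Echo'

        def query(self, chat_history, **kwargs):
            # Return the content of the most recent message verbatim.
            return chat_history[-1]["content"]

        def stream(self, chat_history, **kwargs):
            # Stream the same content as a single chunk.
            yield self.query(chat_history, **kwargs)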
116 changes: 116 additions & 0 deletions mle/model/deepseek.py
@@ -0,0 +1,116 @@
import importlib.util
import json

from mle.function import SEARCH_FUNCTIONS, get_function, process_function_name
from mle.model.common import Model


class DeepSeekModel(Model):
def __init__(self, api_key, model, temperature=0.7):
"""
Initialize the DeepSeek model.
Args:
api_key (str): The DeepSeek API key.
model (str): The model with version.
temperature (float): The temperature value.
"""
super().__init__()

dependency = "openai"
spec = importlib.util.find_spec(dependency)
if spec is not None:
self.openai = importlib.import_module(dependency).OpenAI
else:
raise ImportError(
"It seems you didn't install openai. In order to enable the OpenAI client related features, "
"please make sure openai Python package has been installed. "
"More information, please refer to: https://openai.com/product"
)
self.model = model if model else "deepseek-coder"
self.model_type = 'DeepSeek'
self.temperature = temperature
self.client = self.openai(
api_key=api_key, base_url="https://api.deepseek.com/beta"
)
self.func_call_history = []

def _convert_functions_to_tools(self, functions):
"""
Convert OpenAI-style functions to DeepSeek-style tools.
"""
tools = []
for func in functions:
tool = {
"type": "function",
"function": {
"name": func["name"],
"description": func.get("description", ""),
"parameters": func["parameters"],
},
}
tools.append(tool)
return tools

def query(self, chat_history, **kwargs):
"""
Query the LLM model.

Args:
chat_history: The context (chat history).
"""
functions = kwargs.get("functions", None)
tools = self._convert_functions_to_tools(functions) if functions else None
parameters = kwargs
completion = self.client.chat.completions.create(
model=self.model,
messages=chat_history,
temperature=self.temperature,
stream=False,
tools=tools,
**parameters,
)

resp = completion.choices[0].message
if resp.tool_calls:
for tool_call in resp.tool_calls:
chat_history.append({"role": "assistant", "content": '', "tool_calls": [tool_call], "prefix":False})
function_name = process_function_name(tool_call.function.name)
arguments = json.loads(tool_call.function.arguments)
print("[MLE FUNC CALL]: ", function_name)
self.func_call_history.append({"name": function_name, "arguments": arguments})
# avoid excessive search function calls: disable tool choice after three search attempts
search_attempts = [item for item in self.func_call_history if item['name'] in SEARCH_FUNCTIONS]
if len(search_attempts) > 3:
parameters['tool_choice'] = "none"
result = get_function(function_name)(**arguments)
chat_history.append({"role": "tool", "content": result, "name": function_name, "tool_call_id":tool_call.id})
return self.query(chat_history, **parameters)
else:
return resp.content

def stream(self, chat_history, **kwargs):
"""
Stream the output from the LLM model.
Args:
chat_history: The context (chat history).
"""
arguments = ""
function_name = ""
for chunk in self.client.chat.completions.create(
model=self.model,
messages=chat_history,
temperature=self.temperature,
stream=True,
**kwargs,
):
if chunk.choices[0].delta.tool_calls:
tool_call = chunk.choices[0].delta.tool_calls[0]
if tool_call.function.name:
chat_history.append({"role": "assistant", "content": '', "tool_calls": [tool_call], "prefix":False})
function_name = process_function_name(tool_call.function.name)
arguments = json.loads(tool_call.function.arguments)
result = get_function(function_name)(**arguments)
chat_history.append({"role": "tool", "content": result, "name": function_name})
yield from self.stream(chat_history, **kwargs)
else:
yield chunk.choices[0].delta.content
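
For reference, the message shapes that query appends during one tool round-trip before recursing; all names and ids below are made up, not from this PR:

    # Illustrative only: the assistant tool-call message and the matching tool result.
    chat_history = [{"role": "user", "content": "Search for MNIST baselines."}]
    fake_tool_call = {
        "id": "call_0",
        "type": "function",
        "function": {"name": "web_search", "arguments": '{"query": "MNIST baselines"}'},
    }
    chat_history.append({"role": "assistant", "content": "", "tool_calls": [fake_tool_call], "prefix": False})
    chat_history.append({"role": "tool", "content": "...search results...",
                         "name": "web_search", "tool_call_id": "call_0"})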