Merge pull request #229 from leeeizhang/lei/seperate-models
[MRG] separate LLM models
huangyz0918 authored Oct 8, 2024
2 parents 98d72a4 + ecd171a commit 985c24e
Showing 8 changed files with 620 additions and 595 deletions.
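The single module mle/model.py is split into a provider-per-file package. Inferred from the imports in the new __init__.py below (the mistral, ollama, and openai modules are among the changed files, though their diffs are not shown on this page), the new layout is:

mle/model/
    __init__.py   # load_model() and the ObservableModel wrapper
    common.py     # the Model abstract base class
    anthropic.py  # ClaudeModel
    deepseek.py   # DeepSeekModel
    mistral.py    # MistralModel
    ollama.py     # OllamaModel
    openai.py     # OpenAIModel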
595 changes: 0 additions & 595 deletions mle/model.py

This file was deleted.

70 changes: 70 additions & 0 deletions mle/model/__init__.py
@@ -0,0 +1,70 @@
from .anthropic import *
from .deepseek import *
from .mistral import *
from .ollama import *
from .openai import *

from .common import Model
from mle.utils import get_config


MODEL_OLLAMA = 'Ollama'
MODEL_OPENAI = 'OpenAI'
MODEL_CLAUDE = 'Claude'
MODEL_MISTRAL = 'MistralAI'
MODEL_DEEPSEEK = 'DeepSeek'


class ObservableModel:
"""
A class that wraps a model to make it trackable by the metric platform (e.g., Langfuse).
"""

try:
from mle.utils import get_langfuse_observer
_observe = get_langfuse_observer()
    except Exception:
        # If the observer cannot be loaded, fall back to an identity
        # decorator so queries simply run untracked.
        _observe = lambda fn: fn

def __init__(self, model: Model):
"""
Initialize the ObservableModel.
Args:
model: The model to be wrapped and made observable.
"""
self.model = model

@_observe
def query(self, *args, **kwargs):
return self.model.query(*args, **kwargs)

@_observe
def stream(self, *args, **kwargs):
        return self.model.stream(*args, **kwargs)


def load_model(project_dir: str, model_name: str = None, observable=True):
    """
    Load the model based on the project configuration.
Args:
project_dir (str): The project directory.
model_name (str): The model name.
observable (boolean): Whether the model should be tracked.
"""
config = get_config(project_dir)
model = None

    if config['platform'] == MODEL_OLLAMA:
        model = OllamaModel(model=model_name)
    elif config['platform'] == MODEL_OPENAI:
        model = OpenAIModel(api_key=config['api_key'], model=model_name)
    elif config['platform'] == MODEL_CLAUDE:
        model = ClaudeModel(api_key=config['api_key'], model=model_name)
    elif config['platform'] == MODEL_MISTRAL:
        model = MistralModel(api_key=config['api_key'], model=model_name)
    elif config['platform'] == MODEL_DEEPSEEK:
        model = DeepSeekModel(api_key=config['api_key'], model=model_name)

if observable:
return ObservableModel(model)
return model
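
A usage sketch of the new entry point (the project directory and the config contents resolved by get_config are assumptions for illustration):

from mle.model import load_model

# assumes get_config("./my-project") yields {'platform': 'OpenAI', 'api_key': '...'}
model = load_model("./my-project", model_name="gpt-4o")
reply = model.query([{"role": "user", "content": "Hi"}])  # traced via Langfuse when available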
144 changes: 144 additions & 0 deletions mle/model/anthropic.py
@@ -0,0 +1,144 @@
import importlib.util

from mle.function import SEARCH_FUNCTIONS, get_function, process_function_name
from mle.model.common import Model


class ClaudeModel(Model):
def __init__(self, api_key, model, temperature=0.7):
"""
Initialize the Claude model.
Args:
api_key (str): The Anthropic API key.
model (str): The model with version.
temperature (float): The temperature value.
"""
super().__init__()

dependency = "anthropic"
spec = importlib.util.find_spec(dependency)
if spec is not None:
self.anthropic = importlib.import_module(dependency).Anthropic
else:
raise ImportError(
"It seems you didn't install anthropic. In order to enable the OpenAI client related features, "
"please make sure openai Python package has been installed. "
"More information, please refer to: https://docs.anthropic.com/en/api/client-sdks"
)

self.model = model if model else 'claude-3-5-sonnet-20240620'
self.model_type = 'Claude'
self.temperature = temperature
self.client = self.anthropic(api_key=api_key)
self.func_call_history = []

@staticmethod
def _add_tool_result_into_chat_history(chat_history, func, result):
"""
Add the result of tool calls into messages.
"""
        chat_history.extend([
{
"role": "assistant",
"content": [
{
"type": "tool_use",
"id": func.id,
"name": func.name,
"input": func.input,
},
]
},
{
"role": "user",
"content": [
{
"type": "tool_result",
"tool_use_id": func.id,
"content": result,
},
]
},
])

def query(self, chat_history, **kwargs):
"""
Query the LLM model.
Args:
chat_history: The context (chat history).
"""
        # Claude has no system role in chat_history; the system prompt is passed separately:
        # https://docs.anthropic.com/en/docs/build-with-claude/prompt-engineering/system-prompts
        system_prompt = ""
        for msg in chat_history:
            if msg["role"] == "system":
                system_prompt += msg["content"]

        # Claude does not support a `response_format` parameter, so we append the requirement to the system prompt
        if "response_format" in kwargs:
            system_prompt += (
                f"\nOutput only valid {kwargs['response_format']['type']} without any explanatory words"
            )

        # map the OpenAI function schema to the Claude tool schema
        tools = kwargs.get("functions", [])
        for tool in tools:
            if "parameters" in tool:
                tool["input_schema"] = tool.pop("parameters")

completion = self.client.messages.create(
max_tokens=4096,
model=self.model,
system=system_prompt,
messages=[msg for msg in chat_history if msg["role"] != "system"],
temperature=self.temperature,
stream=False,
tools=tools,
)
if completion.stop_reason == "tool_use":
for func in completion.content:
if func.type != "tool_use":
continue
function_name = process_function_name(func.name)
arguments = func.input
print("[MLE FUNC CALL]: ", function_name)
self.func_call_history.append({"name": function_name, "arguments": arguments})
                # avoid excessive search calls: after 3 search attempts, disable further tool use
search_attempts = [item for item in self.func_call_history if item['name'] in SEARCH_FUNCTIONS]
if len(search_attempts) > 3:
kwargs['functions'] = []
result = get_function(function_name)(**arguments)
self._add_tool_result_into_chat_history(chat_history, func, result)
return self.query(chat_history, **kwargs)
else:
return completion.content[0].text

def stream(self, chat_history, **kwargs):
"""
Stream the output from the LLM model.
Args:
chat_history: The context (chat history).
"""
        # Claude has no system role in chat_history; the system prompt is passed separately:
        # https://docs.anthropic.com/en/docs/build-with-claude/prompt-engineering/system-prompts
        system_prompt = ""
        for msg in chat_history:
            if msg["role"] == "system":
                system_prompt += msg["content"]
chat_history = [msg for msg in chat_history if msg["role"] != "system"]

        # Claude does not support a `response_format` parameter, so we append the requirement to the system prompt
        if "response_format" in kwargs:
            system_prompt += (
                f"\nOutput only valid {kwargs['response_format']['type']} without any explanatory words"
            )

        with self.client.messages.stream(
            max_tokens=4096,
            model=self.model,
            system=system_prompt,
            temperature=self.temperature,
            messages=chat_history,
        ) as stream:
for chunk in stream.text_stream:
yield chunk
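
A usage sketch of the Claude client above (the API key and messages are placeholders; the response_format value mirrors the OpenAI-style convention handled in query):

from mle.model.anthropic import ClaudeModel

model = ClaudeModel(api_key="sk-ant-...", model="claude-3-5-sonnet-20240620")
chat_history = [
    {"role": "system", "content": "You are a helpful ML engineer."},
    {"role": "user", "content": "Suggest a baseline for tabular classification."},
]
# system messages are folded into `system=`, and the JSON requirement is
# appended to the system prompt, since Claude has no `response_format` parameter
print(model.query(chat_history, response_format={"type": "json_object"}))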
18 changes: 18 additions & 0 deletions mle/model/common.py
@@ -0,0 +1,18 @@
from abc import ABC, abstractmethod


class Model(ABC):

def __init__(self):
"""
Initialize the model.
"""
self.model_type = None

    @abstractmethod
    def query(self, chat_history, **kwargs):
        """Return a complete response for the given chat history."""
        pass

    @abstractmethod
    def stream(self, chat_history, **kwargs):
        """Yield the response incrementally for the given chat history."""
        pass
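
A sketch of how a new provider would plug into this contract (EchoModel is hypothetical, purely for illustration):

from mle.model.common import Model

class EchoModel(Model):
    def __init__(self):
        super().__init__()
        self.model_type = 'Echo'

    def query(self, chat_history, **kwargs):
        # return the last user message verbatim
        return chat_history[-1]["content"]

    def stream(self, chat_history, **kwargs):
        # yield the reply word by word, mimicking token streaming
        for token in self.query(chat_history).split():
            yield token + " "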
116 changes: 116 additions & 0 deletions mle/model/deepseek.py
@@ -0,0 +1,116 @@
import importlib.util
import json

from mle.function import SEARCH_FUNCTIONS, get_function, process_function_name
from mle.model.common import Model


class DeepSeekModel(Model):
def __init__(self, api_key, model, temperature=0.7):
"""
Initialize the DeepSeek model.
Args:
api_key (str): The DeepSeek API key.
model (str): The model with version.
temperature (float): The temperature value.
"""
super().__init__()

dependency = "openai"
spec = importlib.util.find_spec(dependency)
if spec is not None:
self.openai = importlib.import_module(dependency).OpenAI
else:
            raise ImportError(
                "It seems you didn't install openai. The DeepSeek client is accessed through the OpenAI SDK, "
                "so please make sure the openai Python package has been installed. "
                "For more information, please refer to: https://openai.com/product"
            )
self.model = model if model else "deepseek-coder"
self.model_type = 'DeepSeek'
self.temperature = temperature
self.client = self.openai(
api_key=api_key, base_url="https://api.deepseek.com/beta"
)
self.func_call_history = []

def _convert_functions_to_tools(self, functions):
"""
Convert OpenAI-style functions to DeepSeek-style tools.
"""
tools = []
for func in functions:
tool = {
"type": "function",
"function": {
"name": func["name"],
"description": func.get("description", ""),
"parameters": func["parameters"],
},
}
tools.append(tool)
return tools
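
For reference, a sketch of what the conversion produces (the web_search schema is illustrative):

functions = [{
    "name": "web_search",
    "description": "Search the web for a query.",
    "parameters": {"type": "object", "properties": {"query": {"type": "string"}}},
}]
# each entry is wrapped as a DeepSeek/OpenAI-style tool:
# {"type": "function", "function": {"name": "web_search", "description": ..., "parameters": ...}}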

def query(self, chat_history, **kwargs):
"""
Query the LLM model.
Args:
chat_history: The context (chat history).
"""
functions = kwargs.get("functions", None)
tools = self._convert_functions_to_tools(functions) if functions else None
parameters = kwargs
completion = self.client.chat.completions.create(
model=self.model,
messages=chat_history,
temperature=self.temperature,
stream=False,
tools=tools,
**parameters,
)

resp = completion.choices[0].message
if resp.tool_calls:
for tool_call in resp.tool_calls:
chat_history.append({"role": "assistant", "content": '', "tool_calls": [tool_call], "prefix":False})
function_name = process_function_name(tool_call.function.name)
arguments = json.loads(tool_call.function.arguments)
print("[MLE FUNC CALL]: ", function_name)
self.func_call_history.append({"name": function_name, "arguments": arguments})
                # avoid excessive search calls: after 3 search attempts, disable further tool use
search_attempts = [item for item in self.func_call_history if item['name'] in SEARCH_FUNCTIONS]
if len(search_attempts) > 3:
parameters['tool_choice'] = "none"
result = get_function(function_name)(**arguments)
chat_history.append({"role": "tool", "content": result, "name": function_name, "tool_call_id":tool_call.id})
return self.query(chat_history, **parameters)
else:
return resp.content

    def stream(self, chat_history, **kwargs):
        """
        Stream the output from the LLM model.
        Args:
            chat_history: The context (chat history).
        """
        function_name = ""
        arguments = ""
        tool_call_id = None
        for chunk in self.client.chat.completions.create(
            model=self.model,
            messages=chat_history,
            temperature=self.temperature,
            stream=True,
            **kwargs,
        ):
            delta = chunk.choices[0].delta
            if delta.tool_calls:
                tool_call = delta.tool_calls[0]
                if tool_call.function.name:
                    # the function name and call id arrive in the first chunk of a tool call
                    function_name = process_function_name(tool_call.function.name)
                    tool_call_id = tool_call.id
                if tool_call.function.arguments:
                    # arguments are streamed incrementally; accumulate them before
                    # parsing, since a single chunk may hold only partial JSON
                    arguments += tool_call.function.arguments
            elif chunk.choices[0].finish_reason == "tool_calls" and function_name:
                # the tool call is complete: record it, execute it, and stream the follow-up
                chat_history.append({
                    "role": "assistant", "content": '', "prefix": False,
                    "tool_calls": [{
                        "id": tool_call_id, "type": "function",
                        "function": {"name": function_name, "arguments": arguments},
                    }],
                })
                result = get_function(function_name)(**json.loads(arguments))
                chat_history.append({"role": "tool", "content": result, "name": function_name, "tool_call_id": tool_call_id})
                yield from self.stream(chat_history, **kwargs)
            elif delta.content:
                yield delta.content
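
A quick sketch of exercising the DeepSeek client directly (the API key and prompt are placeholders):

from mle.model.deepseek import DeepSeekModel

model = DeepSeekModel(api_key="sk-...", model="deepseek-coder")
for chunk in model.stream([{"role": "user", "content": "Explain gradient clipping."}]):
    print(chunk, end="", flush=True)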