-
Notifications
You must be signed in to change notification settings - Fork 23
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #63 from InternLM/mistral
[add] mistral-7b model
- Loading branch information
Showing
17 changed files
with
189 additions
and
13 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
from fastapi import APIRouter, Request, Response | ||
|
||
from openaoe.backend.service.mistral import Mistral | ||
from openaoe.backend.model.openaoe import AoeChatBody | ||
|
||
router = APIRouter() | ||
|
||
|
||
@router.post("/v1/mistral/chat", tags=["Mistral"])
async def mistral_chat(body: AoeChatBody, request: Request, response: Response):
    """
    Chat endpoint for the Mistral 7B model.

    :param body: OpenAOE chat request body
    :param request: incoming request, handed to the Mistral service
    :param response: outgoing response, handed to the Mistral service
    :return: whatever ``Mistral.chat`` returns for the given body
    """
    # Build the per-request service object first, then delegate to it.
    service = Mistral(request, response)
    return await service.chat(body)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
""" | ||
ref. to https://github.com/ollama/ollama/blob/main/docs/api.md | ||
Parameters | ||
model: (required) the model name | ||
messages: the messages of the chat, this can be used to keep a chat memory | ||
The message object has the following fields: | ||
role: the role of the message, either system, user or assistant | ||
content: the content of the message | ||
images (optional): a list of images to include in the message (for multimodal models such as llava) | ||
Advanced parameters (optional): | ||
format: the format to return a response in. Currently the only accepted value is json | ||
options: additional model parameters listed in the documentation for the Modelfile such as temperature | ||
template: the prompt template to use (overrides what is defined in the Modelfile) | ||
stream: if false the response will be returned as a single response object, rather than a stream of objects | ||
keep_alive: controls how long the model will stay loaded into memory following the request (default: 5m) | ||
""" | ||
|
||
from typing import List, Optional, Literal, Dict | ||
from pydantic import BaseModel | ||
|
||
|
||
# One chat turn of the Ollama-style chat payload.
# Documented with comments rather than a class docstring so the schema
# generated from this model is unchanged.
class Message(BaseModel):
    # Speaker of this turn; falls back to "user" when not supplied.
    role: Optional[Literal["user", "system", "assistant"]] = "user"
    # Text of the turn (required).
    content: str
    # Optional attachments; each entry is an image encoded as base64.
    images: Optional[List[str]] = None
|
||
|
||
# Request body sent to the Ollama-style chat endpoint.
# Documented with comments rather than a class docstring so the schema
# generated from this model is unchanged.
class MistralChatBody(BaseModel):
    # Name of the model to run (required).
    model: str
    # Conversation so far, oldest turn first, as built by the caller.
    messages: List[Message]
    # Additional model parameters (e.g. temperature); empty by default.
    options: Optional[Dict] = {}
    # Prompt template override; None leaves the template unset.
    template: Optional[str] = None
    # True asks for a stream of objects, False for one response object.
    stream: Optional[bool] = True
    # How long the model stays loaded in memory after the request.
    keep_alive: Optional[str] = "5m"
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
from typing import Optional, List, Literal | ||
from pydantic import BaseModel | ||
|
||
|
||
# One earlier turn of a conversation, as carried in an OpenAOE request.
# Documented with comments rather than a class docstring so the schema
# generated from this model is unchanged.
class Context(BaseModel):
    # NOTE(review): looks like a duplicate / misspelling of ``sender_type``;
    # confirm whether any client still sends it before removing.
    send_type: str = "assistant"
    # Who produced this turn; defaults to "assistant".
    sender_type: str = "assistant"
    # Text of the turn; empty by default.
    text: str = ""
|
||
|
||
# Names attached to the two sides of a conversation.
# Presumably display / role labels for the human and the bot -- confirm
# against the providers that consume this model.
class RoleMeta(BaseModel):
    # Label for the human side; defaults to "user".
    user_name: Optional[str] = "user"
    # Label for the bot side; defaults to "assistant".
    bot_name: Optional[str] = "assistant"
|
||
|
||
class AoeChatBody(BaseModel):
    """
    OpenAOE general request body
    """
    # Model identifier chosen by the caller (required).
    model: str
    # The new user input for this request (required).
    prompt: str
    # Earlier turns of the conversation; empty when there is no history.
    messages: Optional[List[Context]] = []
    # Optional naming of the user / bot roles.
    role_meta: Optional[RoleMeta] = None
    # Requested response flavour; only "text" and "json" are accepted.
    type: Optional[Literal["text", "json"]] = "json"
    # Whether the caller wants a streamed response; defaults to True.
    stream: Optional[bool] = True
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
import json | ||
|
||
import requests | ||
from fastapi import Request, Response | ||
from sse_starlette import EventSourceResponse | ||
|
||
from openaoe.backend.config.biz_config import get_base_url | ||
from openaoe.backend.config.constant import PROVIDER_MISTRAL | ||
from openaoe.backend.model.openaoe import AoeChatBody | ||
from openaoe.backend.model.mistral import MistralChatBody, Message | ||
|
||
from openaoe.backend.util.log import log | ||
logger = log(__name__) | ||
|
||
|
||
class Mistral:
    """Service that relays an OpenAOE chat request to an Ollama-style
    ``/api/chat`` endpoint and re-emits the answer as server-sent events."""

    def __init__(self, request: Request, response: Response):
        """Keep the FastAPI request/response pair of the current call."""
        self.request = request
        self.response = response

    async def chat(self, body: AoeChatBody):
        """
        Translate an OpenAOE chat body into the upstream payload and stream it.

        :param body: OpenAOE request; ``messages`` is the history and
            ``prompt`` is the new user input
        :return: the ``EventSourceResponse`` built by
            ``chat_response_streaming``
        """
        # ``messages`` is Optional on AoeChatBody, so tolerate an explicit None.
        history = [
            Message(
                # OpenAOE labels the model side 'bot'; upstream expects "assistant".
                role="assistant" if turn.sender_type == 'bot' else turn.sender_type,
                content=turn.text,
            )
            for turn in (body.messages or [])
        ]
        # The new prompt always goes last, as a user turn.
        history.append(Message(role='user', content=body.prompt))

        chat_url = get_base_url(PROVIDER_MISTRAL, body.model) + "/api/chat"
        # NOTE(review): the upstream model name is fixed to "mistral";
        # ``body.model`` is only used to look up the base URL.
        chat_body = MistralChatBody(model="mistral", messages=history)
        return self.chat_response_streaming(chat_url, chat_body)

    def chat_response_streaming(self, chat_url: str, chat_body: MistralChatBody) -> EventSourceResponse:
        """
        POST ``chat_body`` to ``chat_url`` and wrap the streamed reply as SSE.

        Each upstream JSON line becomes one event of the form
        ``{"success": true, "msg": <content>}``. Any failure (connection
        error, non-2xx status, malformed line, upstream ``error`` payload)
        becomes a single ``{"success": "false", "msg": ...}`` event.

        :param chat_url: full URL of the upstream chat endpoint
        :param chat_body: payload to send upstream
        :return: an ``EventSourceResponse`` driving the stream
        """
        async def do_response_streaming():
            try:
                payload = json.loads(chat_body.model_dump_json())
                # NOTE(review): ``requests`` is synchronous, so this blocks the
                # event loop while waiting on upstream; consider an async client.
                # The ``with`` block releases the connection even when the
                # client disconnects mid-stream.
                with requests.post(chat_url, json=payload, stream=True) as res:
                    # Surface HTTP errors instead of ending the stream silently.
                    res.raise_for_status()
                    # The upstream stream is newline-delimited JSON. Read whole
                    # lines: fixed-size chunks can split an object (or a
                    # multi-byte character) or contain several objects.
                    for raw_line in res.iter_lines():
                        if not raw_line:
                            continue  # skip keep-alive / blank lines
                        if isinstance(raw_line, bytes):
                            chunk = raw_line.decode("utf-8")
                        else:
                            chunk = raw_line
                        logger.info(f"chunk: {chunk}")
                        chunk_json = json.loads(chunk)
                        if chunk_json.get("error"):
                            # Report upstream errors through the failure event.
                            raise RuntimeError(chunk_json["error"])
                        message = chunk_json.get("message") or {}
                        yield json.dumps({
                            "success": True,
                            "msg": message.get("content", "")
                        }, ensure_ascii=False)
            except Exception as e:
                logger.error(f"{e}")
                # NOTE(review): "success" is the string "false" here but the
                # boolean True above; kept as-is to preserve the wire format.
                yield json.dumps(
                    {
                        "success": "false",
                        "msg": f"from backend: {e}"
                    }
                )

        return EventSourceResponse(do_response_streaming())
|
||
|
||
|
||
|
Oops, something went wrong.