Skip to content

Commit

Permalink
Merge pull request #58 from InternLM/feat_stream
Browse files Browse the repository at this point in the history
Feat stream
  • Loading branch information
fly2tomato authored Feb 2, 2024
2 parents 920ad8c + 88ff6a0 commit c494cf7
Show file tree
Hide file tree
Showing 16 changed files with 274 additions and 46 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ English | [简体中文](docs/README_zh-CN.md)


## Latest Progress 🎉
- \[February 2024\] Add gemini-pro model
- \[January 2024\] refactor the config-template.yaml to control the backend and the frontend settings at the same time, [click](https://github.com/InternLM/OpenAOE/blob/main/docs/tech-report/config-template.md) to find more introduction about the `config-template.yaml`
- \[January 2024\] Add internlm2-chat-7b model
- \[January 2024\] Released version v0.0.1, officially open source!
Expand Down
1 change: 1 addition & 0 deletions docs/README_zh-CN.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@


## 最新进展 🎉
- \[2024/02\] 添加gemini-pro模型
- \[2024/01\] 重构了config-template.yaml,可以同时配置前后端的设置
- \[2024/01\] 添加 internlm2-chat-7b 模型
- \[2024/01\] 发布v0.0.1版本,正式开源。
Expand Down
2 changes: 1 addition & 1 deletion openaoe/backend/api/route_google.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,5 +13,5 @@ async def palm_chat(body: GooglePalmChatBody):
@param body: request body
@return: response
"""
ret = palm_chat_svc(body)
ret = await palm_chat_svc(body)
return ret
12 changes: 7 additions & 5 deletions openaoe/backend/config/biz_config.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
#!/usr/bin/env python3

import argparse
import os
import sys
Expand All @@ -8,7 +10,7 @@
from openaoe.backend.util.log import log

logger = log(__name__)
biz_config = None
BIZ_CONFIG = None


class BizConfig:
Expand Down Expand Up @@ -71,14 +73,14 @@ def load_config(config_path) -> BizConfig:
logger.error("init configuration failed. Exit")
sys.exit(-1)

global biz_config
biz_config = BizConfig(**m)
global BIZ_CONFIG
BIZ_CONFIG = BizConfig(**m)
logger.info("init configuration successfully.")
return biz_config
return BIZ_CONFIG


def get_model_configuration(provider: str, field, model_name: str = None):
models = biz_config.models_map
models = BIZ_CONFIG.models_map
if not models:
logger.error(f"invalid configuration file")
sys.exit(-1)
Expand Down
9 changes: 9 additions & 0 deletions openaoe/backend/config/config-template.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,15 @@ models:
api:
api_base: https://generativelanguage.googleapis.com
api_key:
gemini-pro:
provider: google
webui:
avatar: 'https://oss.openmmlab.com/frontend/OpenAOE/google-palm.webp'
isStream: true
background: 'linear-gradient(#b55a5a26 0%, #fa5ab1 100%)'
api:
api_base: https://generativelanguage.googleapis.com
api_key:
abab5-chat:
provider: minimax
webui:
Expand Down
11 changes: 10 additions & 1 deletion openaoe/backend/model/aoe_response.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
from pydantic import BaseModel
#!/usr/bin/env python3

from typing import Optional

from pydantic import BaseModel


class AOEResponse(BaseModel):
"""
Expand All @@ -11,3 +14,9 @@ class AOEResponse(BaseModel):
data: Optional[object] = None


class StreamResponse(BaseModel):
    """
    Standard OpenAOE stream response.

    One instance is serialized to JSON per streamed chunk and sent to the
    frontend as an SSE event payload (see base_stream in service/base.py).
    """
    # whether this chunk was produced successfully; False only on errors
    success: Optional[bool] = True
    # text payload of the chunk, or the error message when success is False
    msg: Optional[str] = ""
1 change: 1 addition & 0 deletions openaoe/backend/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,4 @@ pyyaml==6.0.1
httpx==0.25.0
sse-starlette==1.8.2
anyio==3.7.1
jsonstreamer==1.3.8
135 changes: 135 additions & 0 deletions openaoe/backend/service/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
#!/usr/bin/env python3
import json
import sys
import traceback
from io import StringIO
from typing import Callable

import httpx
from fastapi.encoders import jsonable_encoder
from jsonstreamer import ObjectStreamer

from openaoe.backend.config.constant import DEFAULT_TIMEOUT_SECONDS
from openaoe.backend.model.aoe_response import AOEResponse, StreamResponse
from openaoe.backend.util.log import log

logger = log(__name__)


async def base_request(provider: str, url: str, method: str, headers: dict, body=None, timeout=DEFAULT_TIMEOUT_SECONDS,
                       params=None,
                       files=None) -> AOEResponse:
    """
    Common async HTTP request helper shared by all provider services.

    Args:
        provider: provider name, used for log messages only
        url: complete url
        method: request method, e.g. "get" / "post"
        headers: request headers; user-agent, host and any ip-related
            headers are stripped before forwarding
        body: json body only
        timeout: seconds
        params: request params
        files: request file
    Returns:
        AOEResponse whose ``data`` holds the parsed JSON body when the
        response is valid JSON, otherwise the raw response bytes; on request
        failure ``msg`` carries the error text.
    """
    response = AOEResponse()

    # Strip identifying / hop-by-hop headers before forwarding.
    # NOTE: iterating the dict directly (as the original did) yields keys
    # only and fails to unpack into (k, v); .items() is required.
    headers_pure = {}
    for k, v in headers.items():
        k = k.lower()
        if k == "user-agent" or k == "host" or "ip" in k:
            continue
        headers_pure[k] = v

    # Printable, truncated body used only in the error log below.
    # str() keeps this safe for dict/None bodies (slicing a dict raises).
    if "multipart/form-data" in headers_pure.get("content-type", ""):
        body_str = "image"
    else:
        body_str = str(body)
    if len(body_str) > 200:
        body_str = body_str[:200]

    try:
        async with httpx.AsyncClient() as client:
            # Forward the sanitized headers, not the raw incoming ones.
            proxy = await client.request(method, url, headers=headers_pure, json=body, timeout=timeout,
                                         params=params, files=files)
            response.data = proxy.content
            try:
                response.data = json.loads(response.data)
            except (ValueError, TypeError):
                # Not JSON; keep the raw bytes.
                response.data = proxy.content

    except Exception as e:
        response.msg = str(e)
        logger.error(f"[{provider}] url: {url}, method: {method}, headers: {jsonable_encoder(headers)}, "
                     f"body: {body_str} failed, response: {jsonable_encoder(response)}")
    return response


async def base_stream(provider: str, url: str, method: str, headers: dict, stream_callback: Callable, body=None,
                      timeout=DEFAULT_TIMEOUT_SECONDS,
                      params=None,
                      files=None):
    """
    Common stream request helper.

    Opens a streaming HTTP request, feeds every received chunk through a
    jsonstreamer ObjectStreamer, and yields StreamResponse JSON strings
    suitable for an SSE response. ``stream_callback`` is expected to
    ``print`` any extracted text; stdout is temporarily redirected to a
    StringIO buffer so that printed output becomes the streamed payload.

    Args:
        provider: provider name, used for log messages only
        url: complete url
        method: request method
        headers: request headers; user-agent, host and any ip-related
            headers are stripped before forwarding
        stream_callback: ObjectStreamer catch-all listener; executed for
            every parse event, should print() the text to emit
        body: json body only
        timeout: seconds
        params: request params
        files: request file
    Yields:
        StreamResponse JSON strings (one per received chunk); on failure a
        final chunk with success=False and the error message.
    """
    headers_pure = {
        "Content-Type": "application/json"
    }
    # NOTE: the original iterated the dict directly (keys only); .items()
    # is required to unpack (k, v) pairs.
    for k, v in headers.items():
        k = k.lower()
        if k == "user-agent" or k == "host" or "ip" in k:
            continue
        headers_pure[k] = v

    # Printable, truncated body used only in the error log below.
    # str() keeps truncation safe when jsonable_encoder returns a dict.
    body_str = str(jsonable_encoder(body))
    if "content-type" in headers and "multipart/form-data" in headers["content-type"]:
        body_str = "image"
    if len(body_str) > 200:
        body_str = body_str[:200]

    # Keep a handle on the real stdout so the redirection below can always
    # be undone; the original leaked the redirect process-wide.
    original_stdout = sys.stdout
    try:
        with httpx.stream(method, url, json=body, params=params, files=files, headers=headers_pure,
                          timeout=timeout) as proxy:
            if proxy.status_code != 200:
                raise Exception(f"request failed, model status code: {proxy.status_code}")

            # Incremental JSON parser: the callback print()s extracted text,
            # which we capture via the redirected stdout.
            streamer = ObjectStreamer()
            sys.stdout = captured = StringIO()
            streamer.add_catch_all_listener(stream_callback)

            for text in proxy.iter_text():
                streamer.consume(text)
                chunk = captured.getvalue()
                stream_res = json.dumps(jsonable_encoder(StreamResponse(msg=chunk)))
                # Emit the chunk, then clear the capture buffer for the next one.
                yield stream_res
                captured.seek(0)
                captured.truncate()

    except Exception as e:
        # Restore stdout first so the traceback reaches the real console.
        sys.stdout = original_stdout
        print(traceback.format_exc())
        err_res = json.dumps(jsonable_encoder(StreamResponse(
            success=False,
            msg=str(e)
        )))
        logger.error(f"[{provider}] url: {url}, method: {method}, headers: {jsonable_encoder(headers_pure)}, "
                     f"body: {body_str} failed, response: {err_res}")
        yield err_res
    finally:
        # Always undo the redirection, even when the consumer abandons the
        # generator mid-stream.
        sys.stdout = original_stdout
106 changes: 83 additions & 23 deletions openaoe/backend/service/service_google.py
Original file line number Diff line number Diff line change
@@ -1,40 +1,29 @@
import json

import requests
from jsonstreamer import ObjectStreamer
from sse_starlette import EventSourceResponse

from openaoe.backend.config.biz_config import get_api_key, get_base_url
from openaoe.backend.config.constant import *
from openaoe.backend.model.aoe_response import AOEResponse
from openaoe.backend.model.google import GooglePalmChatBody
from openaoe.backend.service.base import base_request, base_stream
from openaoe.backend.util.log import log

logger = log(__name__)


def palm_chat_svc(body: GooglePalmChatBody):
async def palm_chat_svc(body: GooglePalmChatBody):
"""
chat logic for google PaLM model
"""
api_key = get_api_key(PROVIDER_GOOGLE, body.model)
url = get_base_url(PROVIDER_GOOGLE, body.model)
url = f"{url}/google/v1beta2/models/{body.model}:generateMessage?key={api_key}"
messages = [
{"content": msg.content, "author": msg.author}
for msg in body.prompt.messages or []
]
body = {
"prompt": {
"messages": messages
},
"temperature": body.temperature,
"candidate_count": body.candidate_count,
"model": body.model
}
url, params, request_body = _construct_request_data(body)

if "gemini" in body.model:
return EventSourceResponse(
base_stream(PROVIDER_GOOGLE, url, "post", {}, _catch_all, body=request_body, params=params))

try:
response_json = requests.post(
url=url,
data=json.dumps(body)
).json()
response = await base_request(PROVIDER_GOOGLE, url, "post", {}, request_body, params=params)
response_json = response.data
if response_json.get('error') is not None:
err_msg = response_json.get('error').get("message")
return AOEResponse(
Expand All @@ -57,3 +46,74 @@ def palm_chat_svc(body: GooglePalmChatBody):
data=str(e)
)
return base


def _construct_request_data(body: GooglePalmChatBody):
    """
    Build the url, query params and request body for a Google chat call.

    Gemini models go through the v1beta ``streamGenerateContent`` endpoint
    with role-based ``contents``; PaLM models go through the v1beta2
    ``generateMessage`` endpoint with an author-based prompt.

    Args:
        body: incoming chat request; ``author`` "0" marks user messages,
            anything else marks model messages
    Returns:
        (url, params, body) where body is a plain dict ready for JSON.
    """
    model = body.model
    api_base = get_base_url(PROVIDER_GOOGLE, body.model)
    api_key = get_api_key(PROVIDER_GOOGLE, body.model)

    if "gemini" in model:
        # specially process gemini request
        url = f"{api_base}/v1beta/models/{model}:streamGenerateContent"
        params = {
            "key": api_key
        }

        # Gemini requires the conversation to start with a user turn; drop a
        # leading model message. Guard the index: an empty message list
        # would otherwise raise IndexError.
        if body.prompt.messages and body.prompt.messages[0].author == "1":
            body.prompt.messages = body.prompt.messages[1:]
        contents = [
            {
                "role": "user" if msg.author == "0" else "model",
                "parts": [{
                    "text": msg.content
                }]
            }
            for msg in body.prompt.messages
        ]
        body = {
            "contents": contents,
            "generationConfig": {
                "temperature": body.temperature,
                "candidateCount": 1
            }
        }
    else:
        # key is carried both in the url and in params for PaLM; kept as-is.
        url = f"{api_base}/v1beta2/models/{model}:generateMessage?key={api_key}"
        params = {
            "key": api_key
        }

        messages = []
        last_author = ""
        # ignore not answered prompt: when two consecutive messages share an
        # author, drop the earlier one (guard against popping an empty list)
        for item in body.prompt.messages:
            if messages and item.author == last_author:
                messages.pop()
            messages.append({"content": item.content, "author": item.author})
            last_author = item.author
        body = {
            "prompt": {
                "messages": messages
            },
            "temperature": body.temperature,
            "candidate_count": body.candidate_count,
            "model": body.model
        }
    return url, params, body


def _catch_all(event_name, *args):
    """
    jsonstreamer catch-all listener for Gemini stream parsing.

    Prints any extracted text to stdout (base_stream redirects stdout into
    the SSE response buffer).

    Args:
        event_name: jsonstreamer event type
        *args: event payload; for PAIR_EVENT a (key, value) pair, for
            ELEMENT_EVENT the parsed array elements
    """
    # A top-level ("text", value) pair is emitted directly.
    if event_name == ObjectStreamer.PAIR_EVENT and args[0] == "text":
        print(args[1])
        return

    elif event_name != ObjectStreamer.ELEMENT_EVENT:
        return

    for item in args:
        try:
            text = item["candidates"][0]["content"]["parts"][0]["text"]
            print(text)
        except (KeyError, IndexError, TypeError):
            # Narrowed from a bare except: only the lookup errors the
            # indexing chain can raise; anything else should propagate.
            logger.warning(f"parse error, raw: {item}")
9 changes: 9 additions & 0 deletions openaoe/backend/util/log.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,13 @@ def log(name):
return logger


def clear_other_log():
    """
    Silence every registered logger that is not part of AOE.

    Walks the logging manager's registry, skips placeholder entries, and
    raises the level of each logger whose name does not contain "aoe" to
    CRITICAL so third-party libraries stop emitting log noise.
    """
    registry = logging.Logger.manager.loggerDict
    foreign_loggers = (
        entry
        for name, entry in registry.items()
        if isinstance(entry, logging.Logger) and "aoe" not in name
    )
    for entry in foreign_loggers:
        entry.setLevel(logging.CRITICAL)

clear_other_log()
logger = log("util")
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ import React from 'react';

// Default value for the app-wide configuration context. The diff render
// left both the removed `streamProviders` line and its replacement in the
// object literal (a syntax error); only the renamed `streamModels` stays.
export const DefaultConfigInfo = {
    models: null,      // model configuration map; null until loaded from the backend
    streamModels: []   // names of models whose responses are streamed
};

// React context carrying the global configuration to the component tree.
export const GlobalConfigContext = React.createContext(DefaultConfigInfo);
Loading

0 comments on commit c494cf7

Please sign in to comment.