Skip to content

Commit

Permalink
feat: 兼容vllm
Browse files Browse the repository at this point in the history
  • Loading branch information
daizeyao committed Dec 21, 2024
1 parent 78a45e2 commit d0e3a9b
Show file tree
Hide file tree
Showing 5 changed files with 97 additions and 33 deletions.
4 changes: 2 additions & 2 deletions web/OJ/chat-test.php
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
<?php
// echo "Test Chat API<br><br>" . PHP_EOL;
// $endpoint = 'http://172.22.233.71:8081/OJ/chat.php';
// $numberOfRequests = 20; // Number of concurrent requests
// $numberOfRequests = 50; // Number of concurrent requests
// $question = '写个C语言的累加程序';

// $multiHandle = curl_multi_init();
Expand All @@ -20,7 +20,7 @@
// 'Referer: http://172.22.233.71:8081/OJ/',
// 'User-Agent: Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36 Edg/131.0.0.0'
// ]);
// curl_setopt($curlHandles[$i], CURLOPT_COOKIE, 'PHPSESSID=v252sdchi5438pmte0gt8p2oi7');
// curl_setopt($curlHandles[$i], CURLOPT_COOKIE, 'PHPSESSID=dc9ja08djppauf1m5uvpcb73b5');
// curl_setopt($curlHandles[$i], CURLOPT_SSL_VERIFYPEER, false); // --insecure option
// curl_multi_add_handle($multiHandle, $curlHandles[$i]);
// }
Expand Down
27 changes: 16 additions & 11 deletions web/OJ/chat.php
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,7 @@

require_once './include/static.php';
require './class/Class.DFA.php';
require './class/Class.StreamHandler.php';
require './class/Class.OllamaChat.php';
require './class/Class.AICore.php';

echo 'data: ' . json_encode(['time' => date('Y-m-d H:i:s'), 'content' => '']) . PHP_EOL . PHP_EOL;
flush();
Expand All @@ -69,28 +68,34 @@
}
$question = str_ireplace('{[$add$]}', '+', $question);

// api 和 模型选择
$chat = new OllamaChat(
"http://$AI_HOST:11434/api/generate",
rand(1, 100) <= 50 ? "$AI_MODEL1" : "$AI_MODEL2"
);
// api 和 模型选择 和 交互模式
// $chat = new AICore([
// "url" => "http://$AI_HOST:11434/api/chat",
// "model" => rand(1, 100) <= 50 ? "$AI_MODEL1" : "$AI_MODEL1",
// "type" => "chat",
// "stream" => true
// ]);
$chat = new AICore([
// "url" => "http://$AI_HOST:". (rand(1, 100) <= 50 ? "8000" : "8001") ."/v1/chat/completions",
"url" => "http://$AI_HOST:8000/v1/chat/completions",
"model" => "Qwen2.5-7B-Instruct",
"type" => "vllm-chat",
"stream" => true
]);

$DOCUMENT_ROOT = $_SERVER['DOCUMENT_ROOT'];
$dfa = new DFA([
'words_file' => "$DOCUMENT_ROOT/OJ/plugins/hznuojai/dict.txt",
]);
$chat->set_dfa($dfa);


// 开始提问
$chat->qa([
'system' => '你是杭州师范大学在线测评系统的智能代码助手,你负责且只负责回答代码相关的问题,并且使用中文回答,代码部分使用```包围,下面是问题:',
'question' => $question,
]);


// echo "*************************************" . PHP_EOL;
unset($_SESSION['last_chat_time']);

if(file_exists('./include/cache_end.php'))
if (file_exists('./include/cache_end.php'))
require_once('./include/cache_end.php');
40 changes: 28 additions & 12 deletions web/OJ/class/Class.OllamaChat.php → web/OJ/class/Class.AICore.php
Original file line number Diff line number Diff line change
@@ -1,18 +1,23 @@
<?php
require_once './class/Class.StreamHandler.php';

class OllamaChat
class AICore
{
private $api_url = '';
private $streamHandler;
private $question;
private $dfa = NULL;
private $check_sensitive = TRUE;
private $model = '';
private $api_type = '';
private $more_params = [];

public function __construct($url, $model)
public function __construct($params)
{
$this->api_url = $url;
$this->model = $model;
$this->api_url = $params['url'];
$this->model = $params['model'];
$this->api_type = $params['type'];
$this->more_params = array_diff_key($params, array_flip(['url', 'model', 'type']));
}

public function set_dfa(&$dfa)
Expand All @@ -28,7 +33,8 @@ public function qa($params)

$this->question = $params['system'] . $params['question'];
$this->streamHandler = new StreamHandler([
'qmd5' => md5($this->question . '' . time())
'qmd5' => md5($this->question . '' . time()),
'api_type' => $this->api_type
]);
if ($this->check_sensitive) {
$this->streamHandler->set_dfa($this->dfa);
Expand All @@ -40,20 +46,30 @@ public function qa($params)
return;
}

// 根据Ollama API的要求构建请求正文
$json = json_encode([
'prompt' => $this->question,
'model' => $this->model,
]);
// 构建请求 json
if ($this->api_type == 'generate') {
$json = json_encode(array_merge([
'model' => $this->model,
'prompt' => $this->question
], $this->more_params));
} else if ($this->api_type == 'chat' || $this->api_type == 'vllm-chat') {
$json = json_encode(array_merge([
'model' => $this->model,
'messages' => [[
"role" => "system",
"content" => $this->question
]],
], $this->more_params));
}

$headers = array(
"Content-Type: application/json",
);

$this->ollamaApiCall($json, $headers);
$this->openaiApiCall($json, $headers);
}

private function ollamaApiCall($json, $headers)
private function openaiApiCall($json, $headers)
{
// 注意 curl 需要开启 php 拓展
$ch = curl_init();
Expand Down
57 changes: 50 additions & 7 deletions web/OJ/class/Class.StreamHandler.php
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ class StreamHandler
private $punctuation; //停顿符号
private $dfa = NULL;
private $check_sensitive = FALSE;
private $api_type = '';

public function __construct($params)
{
Expand All @@ -17,6 +18,7 @@ public function __construct($params)
$this->qmd5 = $params['qmd5'] ?? time();
$this->chars = [];
$this->punctuation = ['', '', '', '', '', '……'];
$this->api_type = $params['api_type'] ?? '';
}

public function set_dfa(&$dfa)
Expand All @@ -27,14 +29,48 @@ public function set_dfa(&$dfa)
}
}

//data: {"id":"chat-8f9daf1b8b0045a79204658a0b4cde5e","object":"chat.completion.chunk","created":1734703016,"model":"Qwen2.5-7B-Instruct","choices":[{"index":0,"delta":{"content":""},"logprobs":null,"finish_reason":"stop","stop_reason":null}]}
// 转化为{"model":"codellama:13b","created_at":"2024-12-20T14:09:59.126319875Z","message":{"role":"assistant","content":"bot"},"done":false}这种格式
public function vllmToOpenAI($result)
{
if (trim($result) == "data: [DONE]") {
return json_encode(['done' => true]);
}
if (strpos($result, 'data: ') === 0) {
$result = substr($result, 6);
}
$decoded = json_decode($result, true);
if (isset($decoded['choices'][0]['delta']['content'])) {
$content = $decoded['choices'][0]['delta']['content'];
} else {
$content = '';
}

$openAIFormat = [
'model' => $decoded['model'],
'created_at' => date('c', $decoded['created']),
'message' => [
'role' => 'assistant',
'content' => $content
],
'done' => false
];

return json_encode($openAIFormat);
}

public function callback($ch, $data)
{
$origin_data = $data;
if ($this->api_type == 'vllm-chat') {
$data = $this->vllmToOpenAI($data);
}
$this->counter += 1;
file_put_contents('./log/data.' . $this->qmd5 . '.log', $this->counter . '==' . $data . PHP_EOL . '--------------------' . PHP_EOL, FILE_APPEND);

// echo $data;

// $result = json_decode($data, TRUE);
$result = json_decode($data, TRUE);
// echo $origin_data . PHP_EOL;
// print_r($result);

// if (is_array($result)) {
// $this->end('openai 请求错误:' . json_encode($result));
Expand All @@ -55,19 +91,26 @@ public function callback($ch, $data)

$line_data = json_decode($line, TRUE);

if ($line_data['done'] == true) {
if (isset($line_data['done']) && $line_data['done'] == true) {
//数据传输结束
$this->data_buffer = '';
$this->counter = 0;
$this->sensitive_check();
$this->end();
break;
}

$this->sensitive_check($line_data['response']);
if ($this->api_type == 'generate') {
$content = $line_data['response'] ?? NULL;
} else if ($this->api_type == 'chat' || $this->api_type == 'vllm-chat') {
$content = $line_data['message']['content'] ?? NULL;
}
if ($content) {
$this->sensitive_check($content);
}
// echo 'content: ' . $content . PHP_EOL;
}

return strlen($data);
return strlen($origin_data); // 返回值对应原始函数
}

private function sensitive_check($content = NULL)
Expand Down
2 changes: 1 addition & 1 deletion web/OJ/plugins/hznuojai/chat.js
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ class ChatCore {
this.lastLastWord = this.lastWord;
this.lastWord = content;

let isBottom = ((this.msgDiv.scrollHeight - this.msgDiv.clientHeight) - this.msgDiv.scrollTop) < 50;
let isBottom = ((this.msgDiv.scrollHeight - this.msgDiv.clientHeight) - this.msgDiv.scrollTop) < 60;
this.throttledScrollToBottom(isBottom);

this.typingIdx += 1;
Expand Down

0 comments on commit d0e3a9b

Please sign in to comment.