Skip to content

Commit

Permalink
feat: 加入GPT-3.5-turbo模型,和chatgpt一样 (#38)
Browse files Browse the repository at this point in the history
  • Loading branch information
yqchilde authored Mar 2, 2023
1 parent 4599eaa commit 86cee2a
Show file tree
Hide file tree
Showing 5 changed files with 72 additions and 73 deletions.
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ require (
github.com/glebarez/sqlite v1.7.0
github.com/go-co-op/gocron v1.18.0
github.com/imroc/req/v3 v3.32.0
github.com/sashabaranov/go-gpt3 v1.2.1
github.com/sashabaranov/go-gpt3 v1.3.0
github.com/sirupsen/logrus v1.9.0
github.com/spf13/viper v1.15.0
github.com/tidwall/gjson v1.14.4
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,8 @@ github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFR
github.com/rogpeppe/go-internal v1.8.0 h1:FCbCCtXNOY3UtUuHUYaghJg4y7Fd14rXifAYUAtL9R8=
github.com/sashabaranov/go-gpt3 v1.2.1 h1:kfU+vQ1ThI7p+xfwwJC8olEEEWjK3smgKZ3FcYbaLRQ=
github.com/sashabaranov/go-gpt3 v1.2.1/go.mod h1:BIZdbwdzxZbCrcKGMGH6u2eyGe1xFuX9Anmh3tCP8lQ=
github.com/sashabaranov/go-gpt3 v1.3.0 h1:IbvaK2yTnlm7f/oiC2HC9cbzu/4Znt4GkarFiwZ60uI=
github.com/sashabaranov/go-gpt3 v1.3.0/go.mod h1:BIZdbwdzxZbCrcKGMGH6u2eyGe1xFuX9Anmh3tCP8lQ=
github.com/sirupsen/logrus v1.9.0 h1:trlNQbNUG3OdDrDil03MCb1H2o9nJ1x4/5LYw7byDE0=
github.com/sirupsen/logrus v1.9.0/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
github.com/spf13/afero v1.9.3 h1:41FoI0fD7OR7mGcKE/aOiLkGreyf8ifIOQmJANWogMk=
Expand Down
2 changes: 1 addition & 1 deletion plugins/chatgpt/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

### 配置参数

* 插件名:ChatGPT聊天
* 插件名:ChatGPT聊天,已支持`GPT-3.5-turbo`模型
* 权限:所有好友和群聊
* 数据来源:https://beta.openai.com
* 注意:请先私聊机器人配置`appKey`,相关秘钥申请地址点上面链接
Expand Down
47 changes: 31 additions & 16 deletions plugins/chatgpt/gpt.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@ package chatgpt
import (
"context"
"errors"
"fmt"
"strings"
"time"

gogpt "github.com/sashabaranov/go-gpt3"
"github.com/yqchilde/wxbot/engine/robot"

"github.com/yqchilde/wxbot/engine/pkg/log"
)
Expand All @@ -16,7 +18,7 @@ var (
gptModel *GptModel
)

func AskChatGpt(prompt string, delay ...time.Duration) (answer string, err error) {
func AskChatGpt(messages []gogpt.ChatCompletionMessage, delay ...time.Duration) (answer string, err error) {
// 获取客户端
if gptClient == nil {
gptClient, err = getGptClient()
Expand All @@ -39,18 +41,31 @@ func AskChatGpt(prompt string, delay ...time.Duration) (answer string, err error
}

// 请求gpt3
resp, err := gptClient.CreateCompletion(context.Background(), gogpt.CompletionRequest{
Model: gptModel.Model,
Prompt: prompt,
MaxTokens: gptModel.MaxTokens,
Temperature: float32(gptModel.Temperature),
TopP: float32(gptModel.TopP),
PresencePenalty: float32(gptModel.PresencePenalty),
FrequencyPenalty: float32(gptModel.FrequencyPenalty),
Echo: false,
Stop: []string{"Human:", "AI:"},
})
//resp, err := gptClient.CreateCompletion(context.Background(), gogpt.CompletionRequest{
// Model: gptModel.Model,
// Prompt: prompt,
// MaxTokens: gptModel.MaxTokens,
// Temperature: float32(gptModel.Temperature),
// TopP: float32(gptModel.TopP),
// PresencePenalty: float32(gptModel.PresencePenalty),
// FrequencyPenalty: float32(gptModel.FrequencyPenalty),
// Echo: false,
// Stop: []string{"Human:", "AI:"},
//})

chatMessages := []gogpt.ChatCompletionMessage{
{
Role: "system",
Content: fmt.Sprintf("你是一个强大的助手,你是ChatGPT,我将为你起一个名字叫%s,并且你会用中文回答我的问题", robot.GetBot().GetConfig().BotNickname),
},
}
chatMessages = append(chatMessages, messages...)

log.Println("chatMessages: ", chatMessages)
resp, err := gptClient.CreateChatCompletion(context.Background(), gogpt.ChatCompletionRequest{
Model: gptModel.Model,
Messages: chatMessages,
})
// 处理响应回来的错误
if err != nil {
if strings.Contains(err.Error(), "You exceeded your current quota") {
Expand All @@ -61,15 +76,15 @@ func AskChatGpt(prompt string, delay ...time.Duration) (answer string, err error
}
apiKeys = apiKeys[1:]
gptClient = gogpt.NewClient(apiKeys[0].Key)
return AskChatGpt(prompt)
return AskChatGpt(messages)
}
if strings.Contains(err.Error(), "The server had an error while processing your request") {
log.Println("OpenAi服务出现问题,将重试")
return AskChatGpt(prompt)
return AskChatGpt(messages)
}
if strings.Contains(err.Error(), "Client.Timeout exceeded while awaiting headers") {
log.Println("OpenAi服务请求超时,将重试")
return AskChatGpt(prompt)
return AskChatGpt(messages)
}
if strings.Contains(err.Error(), "Please reduce your prompt") {
return "", errors.New("OpenAi免费上下文长度限制为4097个词组,您的上下文长度已超出限制,请发送\"清空会话\"以清空上下文")
Expand All @@ -79,7 +94,7 @@ func AskChatGpt(prompt string, delay ...time.Duration) (answer string, err error
}
return "", err
}
return resp.Choices[0].Text + "\n", nil
return resp.Choices[0].Message.Content, nil
}

// filterAnswer 过滤答案,处理一些符号问题
Expand Down
92 changes: 37 additions & 55 deletions plugins/chatgpt/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,8 @@ type GptModel struct {
}

var defaultGptModel = GptModel{
Model: "text-davinci-003",
MaxTokens: 512,
Temperature: 0.9,
TopP: 1.0,
PresencePenalty: 0.0,
FrequencyPenalty: 0.6,
ImageSize: "512x512",
Model: "gpt-3.5-turbo",
ImageSize: "512x512",
}

func init() {
Expand Down Expand Up @@ -82,10 +77,12 @@ func init() {
return
}

var nullMessage []gogpt.ChatCompletionMessage

// 开始会话
recv, cancel := ctx.EventChannel(ctx.CheckGroupSession()).Repeat()
defer cancel()
chatCTXMap.LoadOrStore(ctx.Event.FromUniqueID, "")
chatCTXMap.LoadOrStore(ctx.Event.FromUniqueID, nullMessage)
ctx.ReplyTextAndAt("收到!已开始ChatGPT连续会话中,输入\"结束会话\"结束会话,或5分钟后自动结束,请开始吧!")
for {
select {
Expand All @@ -106,7 +103,7 @@ func init() {
ctx.ReplyTextAndAt("已结束聊天的上下文语境,您可以重新发起提问")
return
} else if msg == "清空会话" {
chatCTXMap.Store(ctx.Event.FromUniqueID, "")
chatCTXMap.Store(ctx.Event.FromUniqueID, nullMessage)
ctx.ReplyTextAndAt("已清空会话,您可以继续提问新的问题")
continue
} else if strings.HasPrefix(msg, "作画") {
Expand All @@ -126,53 +123,53 @@ func init() {
continue
}

// 整理问题
question := "Human: " + msg + "\nAI: "
var messages []gogpt.ChatCompletionMessage
if c, ok := chatCTXMap.Load(ctx.Event.FromUniqueID); ok {
question = c.(string) + question
messages = append(c.([]gogpt.ChatCompletionMessage), gogpt.ChatCompletionMessage{
Role: "user",
Content: msg,
})
} else {
messages = []gogpt.ChatCompletionMessage{
{
Role: "user",
Content: msg,
},
}
}
answer, err := AskChatGpt(question, 2*time.Second)

answer, err := AskChatGpt(messages, 2*time.Second)
if err != nil {
ctx.ReplyTextAndAt("ChatGPT出错了,Err:" + err.Error())
continue
}
chatCTXMap.Store(ctx.Event.FromUniqueID, question+answer)
if newAnswer, isNeedReply := filterAnswer(answer); isNeedReply {
retryAnswer, err := AskChatGpt(question + "\n" + answer + newAnswer)
if err != nil {
ctx.ReplyTextAndAt("ChatGPT出错了,Err:" + err.Error())
continue
}
chatCTXMap.Store(ctx.Event.FromUniqueID, question+"\n"+answer)
ctx.ReplyTextAndAt(retryAnswer)
} else {
ctx.ReplyTextAndAt(newAnswer)
}
messages = append(messages, gogpt.ChatCompletionMessage{
Role: "assistant",
Content: answer,
})
chatCTXMap.Store(ctx.Event.FromUniqueID, messages)
ctx.ReplyTextAndAt(answer)
}
}
})

// 单独提问,没有上下文处理
engine.OnRegex(`^提问 (.*)$`).SetBlock(true).Handle(func(ctx *robot.Ctx) {
questionRaw := ctx.State["regex_matched"].([]string)[1]
question := "Human: " + questionRaw + "\nAI: "
answer, err := AskChatGpt(question, time.Second)
question := ctx.State["regex_matched"].([]string)[1]

messages := []gogpt.ChatCompletionMessage{
{
Role: "user",
Content: question,
},
}
answer, err := AskChatGpt(messages, time.Second)
if err != nil {
log.Errorf("ChatGPT出错了,Err:%s", err.Error())
ctx.ReplyTextAndAt("ChatGPT出错了,Err:" + err.Error())
return
}
if newAnswer, isNeedRetry := filterAnswer(answer); isNeedRetry {
retryAnswer, err := AskChatGpt(question + "\n" + answer + newAnswer)
if err != nil {
log.Errorf("ChatGPT出错了,Err:%s", err.Error())
ctx.ReplyTextAndAt("ChatGPT出错了,Err:" + err.Error())
return
}
ctx.ReplyTextAndAt(fmt.Sprintf("问:%s \n--------------------\n答:%s", questionRaw, retryAnswer))
} else {
ctx.ReplyTextAndAt(fmt.Sprintf("问:%s \n--------------------\n答:%s", questionRaw, newAnswer))
}
ctx.ReplyTextAndAt(fmt.Sprintf("问:%s \n--------------------\n答:%s", question, answer))
})

// AI作画
Expand Down Expand Up @@ -255,16 +252,6 @@ func init() {
switch k {
case "ModelName":
updates["model"] = v
case "MaxTokens":
updates["max_tokens"] = v
case "Temperature":
updates["temperature"] = v
case "TopP":
updates["top_p"] = v
case "FrequencyPenalty":
updates["frequency_penalty"] = v
case "PresencePenalty":
updates["presence_penalty"] = v
case "ImageSize":
updates["image_size"] = v
default:
Expand Down Expand Up @@ -293,13 +280,8 @@ func init() {
replyMsg := ""
replyMsg += "----------\n"
replyMsg += "ModelName: %s\n"
replyMsg += "MaxTokens: %d\n"
replyMsg += "Temperature: %.2f\n"
replyMsg += "TopP: %.2f\n"
replyMsg += "FrequencyPenalty: %.2f\n"
replyMsg += "PresencePenalty: %.2f\n"
replyMsg += "ImageSize: %s\n----------\n"
replyMsg = fmt.Sprintf(replyMsg, gptModel.Model, gptModel.MaxTokens, gptModel.Temperature, gptModel.TopP, gptModel.FrequencyPenalty, gptModel.PresencePenalty, gptModel.ImageSize)
replyMsg = fmt.Sprintf(replyMsg, gptModel.Model, gptModel.ImageSize)

// key设置
var keys []ApiKey
Expand Down

0 comments on commit 86cee2a

Please sign in to comment.