Skip to content

Commit

Permalink
fix(chatgpt): 修复因全局禁用插件导致开启会话重复消费,修复多个聊天室内禁用chatgpt插件失败 (#42)
Browse files Browse the repository at this point in the history
  • Loading branch information
yqchilde authored Mar 3, 2023
1 parent 3de5d13 commit fd55bfc
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 32 deletions.
27 changes: 12 additions & 15 deletions engine/pkg/log/log.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
package log

import (
"bytes"
"fmt"
"os"
"path"
"runtime"
"strings"
"strconv"

"github.com/sirupsen/logrus"
)
Expand All @@ -20,30 +21,17 @@ var log = &logger{
l: logrus.New(),
}

type Formatter struct{}

func (s *Formatter) Format(entry *logrus.Entry) ([]byte, error) {
timestamp := entry.Time.Format("2006-01-02 15:04:05")
level := strings.ToUpper(entry.Level.String())
if os.Getenv("DEBUG") == "true" || os.Getenv("DEBUG_LOG") == "true" {
return []byte(fmt.Sprintf("[%s] [%s] [%s:%d] %s\n", timestamp, level, log.callerFile, log.callerLine, entry.Message)), nil
} else {
return []byte(fmt.Sprintf("[%s] [%s] %s\n", timestamp, level, entry.Message)), nil
}
}

func init() {
log.l.SetLevel(logrus.TraceLevel)
log.l.SetOutput(os.Stdout)
//log.l.SetFormatter(&Formatter{})
log.l.SetReportCaller(true)
log.l.SetFormatter(&logrus.TextFormatter{
ForceColors: true,
FullTimestamp: true,
TimestampFormat: "2006-01-02 15:04:05",
CallerPrettyfier: func(frame *runtime.Frame) (function string, file string) {
if os.Getenv("DEBUG") == "true" || os.Getenv("DEBUG_LOG") == "true" {
return "", fmt.Sprintf("[%s:%d]", log.callerFile, log.callerLine)
return "", fmt.Sprintf("[%s:%d] [GOID:%d]", log.callerFile, log.callerLine, getGoId())
}
return "", ""
},
Expand Down Expand Up @@ -132,3 +120,12 @@ func Tracef(format string, args ...interface{}) {
getCaller()
log.l.Tracef(format, args...)
}

func getGoId() uint64 {
b := make([]byte, 64)
b = b[:runtime.Stack(b, false)]
b = bytes.TrimPrefix(b, []byte("goroutine "))
b = b[:bytes.IndexByte(b, ' ')]
n, _ := strconv.ParseUint(string(b), 10, 64)
return n
}
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ require (
github.com/glebarez/sqlite v1.7.0
github.com/go-co-op/gocron v1.18.0
github.com/imroc/req/v3 v3.32.0
github.com/sashabaranov/go-gpt3 v1.3.1
github.com/sashabaranov/go-gpt3 v1.3.3
github.com/sirupsen/logrus v1.9.0
github.com/spf13/viper v1.15.0
github.com/tidwall/gjson v1.14.4
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,8 @@ github.com/sashabaranov/go-gpt3 v1.3.0 h1:IbvaK2yTnlm7f/oiC2HC9cbzu/4Znt4GkarFiw
github.com/sashabaranov/go-gpt3 v1.3.0/go.mod h1:BIZdbwdzxZbCrcKGMGH6u2eyGe1xFuX9Anmh3tCP8lQ=
github.com/sashabaranov/go-gpt3 v1.3.1 h1:ACQOAVX5CAV5rHt0oJOBMKo9BNcqVnmxEdjVxcjVAzw=
github.com/sashabaranov/go-gpt3 v1.3.1/go.mod h1:BIZdbwdzxZbCrcKGMGH6u2eyGe1xFuX9Anmh3tCP8lQ=
github.com/sashabaranov/go-gpt3 v1.3.3 h1:S8Zd4YybnBaNMK+w+XGGWgsjQY1R+6QE2n9SLzVna9k=
github.com/sashabaranov/go-gpt3 v1.3.3/go.mod h1:BIZdbwdzxZbCrcKGMGH6u2eyGe1xFuX9Anmh3tCP8lQ=
github.com/sirupsen/logrus v1.9.0 h1:trlNQbNUG3OdDrDil03MCb1H2o9nJ1x4/5LYw7byDE0=
github.com/sirupsen/logrus v1.9.0/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
github.com/spf13/afero v1.9.3 h1:41FoI0fD7OR7mGcKE/aOiLkGreyf8ifIOQmJANWogMk=
Expand Down
53 changes: 37 additions & 16 deletions plugins/chatgpt/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,16 @@ import (
)

var (
db sqlite.DB // 数据库
chatCTXMap sync.Map // 群号/私聊:消息上下文
chatDone = make(chan struct{}) // 用于结束会话
db sqlite.DB // 数据库
msgContext sync.Map // 群号/私聊:消息上下文
chatRoom = make(map[string]ChatRoom) // 连续会话聊天室
)

type ChatRoom struct {
wxId string
done chan struct{}
}

// ApiKey apikey表,存放openai key
type ApiKey struct {
Key string `gorm:"column:key;index"`
Expand Down Expand Up @@ -58,7 +63,11 @@ func init() {
DataFolder: "chatgpt",
OnDisable: func(ctx *robot.Ctx) {
ctx.ReplyText("禁用成功")
chatDone <- struct{}{}
wxId := ctx.Event.FromUniqueID
if room, ok := chatRoom[wxId]; ok {
close(room.done)
delete(chatRoom, wxId)
}
},
})

Expand All @@ -76,39 +85,51 @@ func init() {

// 连续会话
engine.OnFullMatch("开始会话").SetBlock(true).Handle(func(ctx *robot.Ctx) {
wxId := ctx.Event.FromUniqueID
// 检查是否已经在进行会话
if _, ok := chatCTXMap.Load(ctx.Event.FromUniqueID); ok {
if _, ok := chatRoom[wxId]; ok {
ctx.ReplyTextAndAt("当前已经在会话中了")
return
}

var nullMessage []gogpt.ChatCompletionMessage
var (
nullMessage []gogpt.ChatCompletionMessage
room = ChatRoom{
wxId: wxId,
done: make(chan struct{}),
}
)

chatRoom[wxId] = room

// 开始会话
recv, cancel := ctx.EventChannel(ctx.CheckGroupSession()).Repeat()
defer cancel()
chatCTXMap.LoadOrStore(ctx.Event.FromUniqueID, nullMessage)
msgContext.LoadOrStore(wxId, nullMessage)
ctx.ReplyTextAndAt("收到!已开始ChatGPT连续会话中,输入\"结束会话\"结束会话,或5分钟后自动结束,请开始吧!")
for {
select {
case <-time.After(time.Minute * 5):
chatCTXMap.LoadAndDelete(ctx.Event.FromUniqueID)
msgContext.LoadAndDelete(wxId)
ctx.ReplyTextAndAt("😊检测到您已有5分钟不再提问,那我先主动结束会话咯")
return
case <-chatDone:
chatCTXMap.LoadAndDelete(ctx.Event.FromUniqueID)
ctx.ReplyTextAndAt("已退出ChatGPT")
return
case <-room.done:
if room.wxId == wxId {
msgContext.LoadAndDelete(wxId)
ctx.ReplyTextAndAt("已退出ChatGPT")
return
}
case ctx := <-recv:
wxId := ctx.Event.FromUniqueID
msg := ctx.MessageString()
if msg == "" {
continue
} else if msg == "结束会话" {
chatCTXMap.LoadAndDelete(ctx.Event.FromUniqueID)
msgContext.LoadAndDelete(wxId)
ctx.ReplyTextAndAt("已结束聊天的上下文语境,您可以重新发起提问")
return
} else if msg == "清空会话" {
chatCTXMap.Store(ctx.Event.FromUniqueID, nullMessage)
msgContext.Store(wxId, nullMessage)
ctx.ReplyTextAndAt("已清空会话,您可以继续提问新的问题")
continue
} else if strings.HasPrefix(msg, "作画") {
Expand All @@ -129,7 +150,7 @@ func init() {
}

var messages []gogpt.ChatCompletionMessage
if c, ok := chatCTXMap.Load(ctx.Event.FromUniqueID); ok {
if c, ok := msgContext.Load(wxId); ok {
messages = append(c.([]gogpt.ChatCompletionMessage), gogpt.ChatCompletionMessage{
Role: "user",
Content: msg,
Expand All @@ -152,7 +173,7 @@ func init() {
Role: "assistant",
Content: answer,
})
chatCTXMap.Store(ctx.Event.FromUniqueID, messages)
msgContext.Store(wxId, messages)
ctx.ReplyTextAndAt(answer)
}
}
Expand Down

0 comments on commit fd55bfc

Please sign in to comment.