feat: better stop signal (#181)

feat: better stop signal (#181)
Co-Authored-By: Minghan Zhang <112773885+zmh-program@users.noreply.github.com>
This commit is contained in:
Deng Junhai 2024-06-27 22:57:17 +08:00
parent 9c596a983a
commit 401de5ace7
7 changed files with 166 additions and 232 deletions

View File

@ -2,9 +2,9 @@ package web
import ( import (
"chat/globals" "chat/globals"
"chat/manager/conversation"
"chat/utils" "chat/utils"
"fmt" "fmt"
"strings"
"time" "time"
) )
@ -12,7 +12,7 @@ type Hook func(message []globals.Message, token int) (string, error)
func ChatWithWeb(message []globals.Message) []globals.Message { func ChatWithWeb(message []globals.Message) []globals.Message {
data := utils.GetSegmentString( data := utils.GetSegmentString(
SearchWebResult(GetPointByLatestMessage(message)), 2048, SearchWebResult(message[len(message)-1].Content), 2048,
) )
return utils.Insert(message, 0, globals.Message{ return utils.Insert(message, 0, globals.Message{
@ -24,52 +24,20 @@ func ChatWithWeb(message []globals.Message) []globals.Message {
}) })
} }
func StringCleaner(content string) string { func UsingWebSegment(instance *conversation.Conversation, restart bool) []globals.Message {
for _, replacer := range []string{",", "、", "", "。", "", ":", "", ";", "", "!", "=", "", "?", "", "", "(", ")", "关键字", "空", "1+1"} { segment := conversation.CopyMessage(instance.GetChatMessage(restart))
content = strings.ReplaceAll(content, replacer, " ")
if instance.IsEnableWeb() {
segment = ChatWithWeb(segment)
} }
return strings.TrimSpace(content)
return segment
} }
func GetKeywordPoint(hook Hook, message []globals.Message) string { func UsingWebNativeSegment(enable bool, message []globals.Message) []globals.Message {
resp, _ := hook([]globals.Message{{ if enable {
Role: globals.System, return ChatWithWeb(message)
Content: "If the user input content require ONLINE SEARCH to get the results, please output these keywords to refine the data Interval with space, remember not to answer other content, json format return, format {\"keyword\": \"...\" }", } else {
}, { return message
Role: globals.User,
Content: "你是谁",
}, {
Role: globals.Assistant,
Content: "{\"keyword\":\"\"}",
}, {
Role: globals.User,
Content: "那fystart起始页是什么 和深能科创有什么关系",
}, {
Role: globals.Assistant,
Content: "{\"keyword\":\"fystart起始页 深能科创 关系\"}",
}, {
Role: globals.User,
Content: "1+1=?",
}, {
Role: globals.Assistant,
Content: "{\"keyword\":\"\"}",
}, {
Role: globals.User,
Content: "?",
}, {
Role: globals.Assistant,
Content: "{\"keyword\":\"\"}",
}, {
Role: globals.User,
Content: message[len(message)-1].Content,
}}, 40)
keyword := utils.UnmarshalJson[map[string]interface{}](resp)
if keyword == nil {
return ""
} }
return StringCleaner(keyword["keyword"].(string))
}
func GetPointByLatestMessage(message []globals.Message) string {
return StringCleaner(message[len(message)-1].Content)
} }

View File

@ -1,68 +0,0 @@
package web
import (
"chat/utils"
"golang.org/x/net/html"
"regexp"
"strings"
)
var unexpected = []string{
"<cite>",
"<span class=\"sb_count\">",
"<div class=\"ntf_label toggle_label nt_tit\" id=\"ntf_newtabfil_label\">",
"<img role=\"presentation\"",
}
func ParseBing(source string) string {
body := SplitPagination(GetMainBody(source))
res := strings.Join(GetContent(body), " ")
return html.UnescapeString(res)
}
func GetMainBody(html string) string {
suf := utils.TryGet(strings.Split(html, "<main aria-label=\"搜尋結果\">"), 1)
return strings.Split(suf, "</main>")[0]
}
func SplitPagination(html string) string {
pre := strings.Split(html, "<li class=\"b_msg b_canvas\">")[0]
return utils.TryGet(strings.Split(pre, "<div class=\"ntf_label toggle_label nt_tit\" id=\"ntf_newtabfil_label\">在新选项卡中打开链接</div>"), 1)
}
func GetContent(html string) []string {
re := regexp.MustCompile(`>([^<]+)<`)
matches := re.FindAllString(html, -1)
return FilterContent(matches)
}
func IsExpected(data string) bool {
if IsLink(data) {
return false
}
for _, str := range unexpected {
if strings.HasPrefix(data, str) {
return false
}
}
return true
}
func IsLink(input string) bool {
re := regexp.MustCompile(`^(https?|ftp):\/\/[^\s/$.?#].\S*$`)
return re.MatchString(input)
}
func FilterContent(matches []string) []string {
res := make([]string, 0)
for _, match := range matches {
source := strings.TrimSpace(match[1 : len(match)-1])
if len(source) > 0 && IsExpected(source) {
res = append(res, source)
}
}
return res
}

View File

@ -4,36 +4,8 @@ import (
"chat/globals" "chat/globals"
"chat/utils" "chat/utils"
"fmt" "fmt"
"net/url"
) )
func GetBingUrl(q string) string {
return "https://bing.com/search?q=" + url.QueryEscape(q)
}
func RequestWithUA(url string) string {
data, err := utils.GetRaw(url, map[string]string{
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:109.0) Gecko/20100101 Firefox/116.0",
"Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,*/*;q=0.8",
})
if err != nil {
return ""
}
return data
}
func SearchReverse(q string) string {
// deprecated
uri := GetBingUrl(q)
if res := CallPilotAPI(uri); res != nil {
return utils.Marshal(res.Results)
}
data := RequestWithUA(uri)
return ParseBing(data)
}
func SearchWebResult(q string) string { func SearchWebResult(q string) string {
res, err := CallDuckDuckGoAPI(q) res, err := CallDuckDuckGoAPI(q)
if err != nil { if err != nil {

View File

@ -1,24 +0,0 @@
package web
import (
"chat/globals"
"chat/manager/conversation"
)
func UsingWebSegment(instance *conversation.Conversation, restart bool) []globals.Message {
segment := conversation.CopyMessage(instance.GetChatMessage(restart))
if instance.IsEnableWeb() {
segment = ChatWithWeb(segment)
}
return segment
}
func UsingWebNativeSegment(enable bool, message []globals.Message) []globals.Message {
if enable {
return ChatWithWeb(message)
} else {
return message
}
}

View File

@ -1,36 +0,0 @@
package web
import (
"chat/utils"
"github.com/google/uuid"
)
type PilotResponseResult struct {
Title string `json:"title"`
Link string `json:"link"`
Snippet string `json:"snippet"`
}
type PilotResponse struct {
Results []PilotResponseResult `json:"extra_search_results" required:"true"`
}
func GenerateFriendUID() string {
return uuid.New().String()
}
func CallPilotAPI(url string) *PilotResponse {
data, err := utils.Post("https://webreader.webpilotai.com/api/visit-web", map[string]string{
"Content-Type": "application/json",
"WebPilot-Friend-UID": GenerateFriendUID(),
}, map[string]interface{}{
"link": url,
"user_has_request": false,
})
if err != nil {
return nil
}
return utils.MapToStruct[PilotResponse](data)
}

View File

@ -2,7 +2,7 @@ package manager
import ( import (
"chat/adapter" "chat/adapter"
"chat/adapter/common" adaptercommon "chat/adapter/common"
"chat/addition/web" "chat/addition/web"
"chat/admin" "chat/admin"
"chat/auth" "chat/auth"
@ -10,12 +10,19 @@ import (
"chat/globals" "chat/globals"
"chat/manager/conversation" "chat/manager/conversation"
"chat/utils" "chat/utils"
"database/sql"
"errors"
"fmt" "fmt"
"github.com/gin-gonic/gin"
"runtime/debug" "runtime/debug"
"strings"
"github.com/gin-gonic/gin"
"github.com/go-redis/redis/v8"
) )
const defaultMessage = "empty response" const defaultMessage = "empty response"
const interruptMessage = "interrupted"
func CollectQuota(c *gin.Context, user *auth.User, buffer *utils.Buffer, uncountable bool, err error) { func CollectQuota(c *gin.Context, user *auth.User, buffer *utils.Buffer, uncountable bool, err error) {
db := utils.GetDBFromContext(c) db := utils.GetDBFromContext(c)
@ -34,6 +41,148 @@ func CollectQuota(c *gin.Context, user *auth.User, buffer *utils.Buffer, uncount
} }
} }
type partialChunk struct {
Chunk *globals.Chunk
End bool
Hit bool
Error error
}
func createStopSignal(conn *Connection) chan bool {
stopSignal := make(chan bool, 1)
go func(conn *Connection, stopSignal chan bool) {
defer func() {
if r := recover(); r != nil && !strings.Contains(fmt.Sprintf("%s", r), "closed channel") {
stack := debug.Stack()
globals.Warn(fmt.Sprintf("caught panic from stop signal: %s\n%s", r, stack))
}
}()
for {
if conn.PeekStop() != nil {
stopSignal <- true
break
}
}
}(conn, stopSignal)
return stopSignal
}
func createChatTask(
conn *Connection, user *auth.User, buffer *utils.Buffer, db *sql.DB, cache *redis.Client,
model string, instance *conversation.Conversation, segment []globals.Message, plan bool,
) (hit bool, err error) {
chunkChan := make(chan partialChunk, 24) // the channel to send the chunk data
interruptSignal := make(chan error, 1) // the signal to interrupt the chat task routine
stopSignal := createStopSignal(conn) // the signal to stop from the client
defer func() {
// close all channels
close(interruptSignal)
close(stopSignal)
close(chunkChan)
}()
// create a new chat request routine
go func() {
defer func() {
if r := recover(); r != nil && !strings.Contains(fmt.Sprintf("%s", r), "closed channel") {
stack := debug.Stack()
globals.Warn(fmt.Sprintf("caught panic from chat request: %s\n%s", r, stack))
}
}()
hit, err := channel.NewChatRequestWithCache(
cache, buffer,
auth.GetGroup(db, user),
&adaptercommon.ChatProps{
Model: model,
Message: segment,
MaxTokens: instance.GetMaxTokens(),
Temperature: instance.GetTemperature(),
TopP: instance.GetTopP(),
TopK: instance.GetTopK(),
PresencePenalty: instance.GetPresencePenalty(),
FrequencyPenalty: instance.GetFrequencyPenalty(),
RepetitionPenalty: instance.GetRepetitionPenalty(),
},
// the function to handle the chunk data
func(data *globals.Chunk) error {
// if interrupt signal is received
if len(interruptSignal) > 0 {
return errors.New(interruptMessage)
}
// send the chunk data to the channel
chunkChan <- partialChunk{
Chunk: data,
End: false,
Hit: true,
Error: nil,
}
return nil
},
)
// chat request routine is done
chunkChan <- partialChunk{
Chunk: nil,
End: true,
Hit: hit,
Error: err,
}
}()
for {
select {
case data := <-chunkChan:
if data.Error != nil && data.Error.Error() == interruptMessage {
// skip the interrupt message
continue
}
hit = data.Hit
err = data.Error
if data.End {
return
}
sendPackError := conn.SendClient(globals.ChatSegmentResponse{
Message: buffer.WriteChunk(data.Chunk),
Quota: buffer.GetQuota(),
End: false,
Plan: plan,
})
if sendPackError != nil {
globals.Warn(fmt.Sprintf("failed to send message to client: %s", sendPackError.Error()))
_ = conn.SendClient(globals.ChatSegmentResponse{
Message: sendPackError.Error(),
Quota: buffer.GetQuota(),
End: true,
Plan: plan,
})
interruptSignal <- sendPackError
return hit, sendPackError
}
case <-stopSignal:
globals.Info(fmt.Sprintf("client stopped the chat request (model: %s, client: %s)", model, conn.GetCtx().ClientIP()))
_ = conn.SendClient(globals.ChatSegmentResponse{
Quota: buffer.GetQuota(),
End: true,
Plan: plan,
})
interruptSignal <- errors.New("signal")
return
}
}
}
func ChatHandler(conn *Connection, user *auth.User, instance *conversation.Conversation, restart bool) string { func ChatHandler(conn *Connection, user *auth.User, instance *conversation.Conversation, restart bool) string {
defer func() { defer func() {
if err := recover(); err != nil { if err := recover(); err != nil {
@ -66,34 +215,7 @@ func ChatHandler(conn *Connection, user *auth.User, instance *conversation.Conve
} }
buffer := utils.NewBuffer(model, segment, channel.ChargeInstance.GetCharge(model)) buffer := utils.NewBuffer(model, segment, channel.ChargeInstance.GetCharge(model))
hit, err := channel.NewChatRequestWithCache( hit, err := createChatTask(conn, user, buffer, db, cache, model, instance, segment, plan)
cache, buffer,
auth.GetGroup(db, user),
&adaptercommon.ChatProps{
Model: model,
Message: segment,
Buffer: *buffer,
MaxTokens: instance.GetMaxTokens(),
Temperature: instance.GetTemperature(),
TopP: instance.GetTopP(),
TopK: instance.GetTopK(),
PresencePenalty: instance.GetPresencePenalty(),
FrequencyPenalty: instance.GetFrequencyPenalty(),
RepetitionPenalty: instance.GetRepetitionPenalty(),
},
func(data *globals.Chunk) error {
if signal := conn.PeekStop(); signal != nil {
// stop signal from client
return fmt.Errorf("signal")
}
return conn.SendClient(globals.ChatSegmentResponse{
Message: buffer.WriteChunk(data),
Quota: buffer.GetQuota(),
End: false,
Plan: plan,
})
},
)
admin.AnalysisRequest(model, buffer, err) admin.AnalysisRequest(model, buffer, err)
if adapter.IsAvailableError(err) { if adapter.IsAvailableError(err) {