mirror of
https://github.com/coaidev/coai.git
synced 2025-05-19 21:10:18 +09:00
feat: better stop signal (#181)
feat: better stop signal (#181) Co-Authored-By: Minghan Zhang <112773885+zmh-program@users.noreply.github.com>
This commit is contained in:
parent
9c596a983a
commit
401de5ace7
@ -2,9 +2,9 @@ package web
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"chat/globals"
|
"chat/globals"
|
||||||
|
"chat/manager/conversation"
|
||||||
"chat/utils"
|
"chat/utils"
|
||||||
"fmt"
|
"fmt"
|
||||||
"strings"
|
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -12,7 +12,7 @@ type Hook func(message []globals.Message, token int) (string, error)
|
|||||||
|
|
||||||
func ChatWithWeb(message []globals.Message) []globals.Message {
|
func ChatWithWeb(message []globals.Message) []globals.Message {
|
||||||
data := utils.GetSegmentString(
|
data := utils.GetSegmentString(
|
||||||
SearchWebResult(GetPointByLatestMessage(message)), 2048,
|
SearchWebResult(message[len(message)-1].Content), 2048,
|
||||||
)
|
)
|
||||||
|
|
||||||
return utils.Insert(message, 0, globals.Message{
|
return utils.Insert(message, 0, globals.Message{
|
||||||
@ -24,52 +24,20 @@ func ChatWithWeb(message []globals.Message) []globals.Message {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func StringCleaner(content string) string {
|
func UsingWebSegment(instance *conversation.Conversation, restart bool) []globals.Message {
|
||||||
for _, replacer := range []string{",", "、", ",", "。", ":", ":", ";", ";", "!", "!", "=", "?", "?", "(", ")", "(", ")", "关键字", "空", "1+1"} {
|
segment := conversation.CopyMessage(instance.GetChatMessage(restart))
|
||||||
content = strings.ReplaceAll(content, replacer, " ")
|
|
||||||
|
if instance.IsEnableWeb() {
|
||||||
|
segment = ChatWithWeb(segment)
|
||||||
}
|
}
|
||||||
return strings.TrimSpace(content)
|
|
||||||
|
return segment
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetKeywordPoint(hook Hook, message []globals.Message) string {
|
func UsingWebNativeSegment(enable bool, message []globals.Message) []globals.Message {
|
||||||
resp, _ := hook([]globals.Message{{
|
if enable {
|
||||||
Role: globals.System,
|
return ChatWithWeb(message)
|
||||||
Content: "If the user input content require ONLINE SEARCH to get the results, please output these keywords to refine the data Interval with space, remember not to answer other content, json format return, format {\"keyword\": \"...\" }",
|
} else {
|
||||||
}, {
|
return message
|
||||||
Role: globals.User,
|
|
||||||
Content: "你是谁",
|
|
||||||
}, {
|
|
||||||
Role: globals.Assistant,
|
|
||||||
Content: "{\"keyword\":\"\"}",
|
|
||||||
}, {
|
|
||||||
Role: globals.User,
|
|
||||||
Content: "那fystart起始页是什么 和深能科创有什么关系",
|
|
||||||
}, {
|
|
||||||
Role: globals.Assistant,
|
|
||||||
Content: "{\"keyword\":\"fystart起始页 深能科创 关系\"}",
|
|
||||||
}, {
|
|
||||||
Role: globals.User,
|
|
||||||
Content: "1+1=?",
|
|
||||||
}, {
|
|
||||||
Role: globals.Assistant,
|
|
||||||
Content: "{\"keyword\":\"\"}",
|
|
||||||
}, {
|
|
||||||
Role: globals.User,
|
|
||||||
Content: "?",
|
|
||||||
}, {
|
|
||||||
Role: globals.Assistant,
|
|
||||||
Content: "{\"keyword\":\"\"}",
|
|
||||||
}, {
|
|
||||||
Role: globals.User,
|
|
||||||
Content: message[len(message)-1].Content,
|
|
||||||
}}, 40)
|
|
||||||
keyword := utils.UnmarshalJson[map[string]interface{}](resp)
|
|
||||||
if keyword == nil {
|
|
||||||
return ""
|
|
||||||
}
|
}
|
||||||
return StringCleaner(keyword["keyword"].(string))
|
|
||||||
}
|
|
||||||
|
|
||||||
func GetPointByLatestMessage(message []globals.Message) string {
|
|
||||||
return StringCleaner(message[len(message)-1].Content)
|
|
||||||
}
|
}
|
||||||
|
@ -1,68 +0,0 @@
|
|||||||
package web
|
|
||||||
|
|
||||||
import (
|
|
||||||
"chat/utils"
|
|
||||||
"golang.org/x/net/html"
|
|
||||||
"regexp"
|
|
||||||
"strings"
|
|
||||||
)
|
|
||||||
|
|
||||||
var unexpected = []string{
|
|
||||||
"<cite>",
|
|
||||||
"<span class=\"sb_count\">",
|
|
||||||
"<div class=\"ntf_label toggle_label nt_tit\" id=\"ntf_newtabfil_label\">",
|
|
||||||
"<img role=\"presentation\"",
|
|
||||||
}
|
|
||||||
|
|
||||||
func ParseBing(source string) string {
|
|
||||||
body := SplitPagination(GetMainBody(source))
|
|
||||||
res := strings.Join(GetContent(body), " ")
|
|
||||||
return html.UnescapeString(res)
|
|
||||||
}
|
|
||||||
|
|
||||||
func GetMainBody(html string) string {
|
|
||||||
suf := utils.TryGet(strings.Split(html, "<main aria-label=\"搜尋結果\">"), 1)
|
|
||||||
return strings.Split(suf, "</main>")[0]
|
|
||||||
}
|
|
||||||
|
|
||||||
func SplitPagination(html string) string {
|
|
||||||
pre := strings.Split(html, "<li class=\"b_msg b_canvas\">")[0]
|
|
||||||
return utils.TryGet(strings.Split(pre, "<div class=\"ntf_label toggle_label nt_tit\" id=\"ntf_newtabfil_label\">在新选项卡中打开链接</div>"), 1)
|
|
||||||
}
|
|
||||||
|
|
||||||
func GetContent(html string) []string {
|
|
||||||
re := regexp.MustCompile(`>([^<]+)<`)
|
|
||||||
matches := re.FindAllString(html, -1)
|
|
||||||
|
|
||||||
return FilterContent(matches)
|
|
||||||
}
|
|
||||||
|
|
||||||
func IsExpected(data string) bool {
|
|
||||||
if IsLink(data) {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
for _, str := range unexpected {
|
|
||||||
if strings.HasPrefix(data, str) {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
func IsLink(input string) bool {
|
|
||||||
re := regexp.MustCompile(`^(https?|ftp):\/\/[^\s/$.?#].\S*$`)
|
|
||||||
return re.MatchString(input)
|
|
||||||
}
|
|
||||||
|
|
||||||
func FilterContent(matches []string) []string {
|
|
||||||
res := make([]string, 0)
|
|
||||||
|
|
||||||
for _, match := range matches {
|
|
||||||
source := strings.TrimSpace(match[1 : len(match)-1])
|
|
||||||
if len(source) > 0 && IsExpected(source) {
|
|
||||||
res = append(res, source)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return res
|
|
||||||
}
|
|
@ -4,36 +4,8 @@ import (
|
|||||||
"chat/globals"
|
"chat/globals"
|
||||||
"chat/utils"
|
"chat/utils"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/url"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func GetBingUrl(q string) string {
|
|
||||||
return "https://bing.com/search?q=" + url.QueryEscape(q)
|
|
||||||
}
|
|
||||||
|
|
||||||
func RequestWithUA(url string) string {
|
|
||||||
data, err := utils.GetRaw(url, map[string]string{
|
|
||||||
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:109.0) Gecko/20100101 Firefox/116.0",
|
|
||||||
"Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,*/*;q=0.8",
|
|
||||||
})
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
return data
|
|
||||||
}
|
|
||||||
|
|
||||||
func SearchReverse(q string) string {
|
|
||||||
// deprecated
|
|
||||||
uri := GetBingUrl(q)
|
|
||||||
if res := CallPilotAPI(uri); res != nil {
|
|
||||||
return utils.Marshal(res.Results)
|
|
||||||
}
|
|
||||||
data := RequestWithUA(uri)
|
|
||||||
return ParseBing(data)
|
|
||||||
}
|
|
||||||
|
|
||||||
func SearchWebResult(q string) string {
|
func SearchWebResult(q string) string {
|
||||||
res, err := CallDuckDuckGoAPI(q)
|
res, err := CallDuckDuckGoAPI(q)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -1,24 +0,0 @@
|
|||||||
package web
|
|
||||||
|
|
||||||
import (
|
|
||||||
"chat/globals"
|
|
||||||
"chat/manager/conversation"
|
|
||||||
)
|
|
||||||
|
|
||||||
func UsingWebSegment(instance *conversation.Conversation, restart bool) []globals.Message {
|
|
||||||
segment := conversation.CopyMessage(instance.GetChatMessage(restart))
|
|
||||||
|
|
||||||
if instance.IsEnableWeb() {
|
|
||||||
segment = ChatWithWeb(segment)
|
|
||||||
}
|
|
||||||
|
|
||||||
return segment
|
|
||||||
}
|
|
||||||
|
|
||||||
func UsingWebNativeSegment(enable bool, message []globals.Message) []globals.Message {
|
|
||||||
if enable {
|
|
||||||
return ChatWithWeb(message)
|
|
||||||
} else {
|
|
||||||
return message
|
|
||||||
}
|
|
||||||
}
|
|
@ -1,36 +0,0 @@
|
|||||||
package web
|
|
||||||
|
|
||||||
import (
|
|
||||||
"chat/utils"
|
|
||||||
"github.com/google/uuid"
|
|
||||||
)
|
|
||||||
|
|
||||||
type PilotResponseResult struct {
|
|
||||||
Title string `json:"title"`
|
|
||||||
Link string `json:"link"`
|
|
||||||
Snippet string `json:"snippet"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type PilotResponse struct {
|
|
||||||
Results []PilotResponseResult `json:"extra_search_results" required:"true"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func GenerateFriendUID() string {
|
|
||||||
return uuid.New().String()
|
|
||||||
}
|
|
||||||
|
|
||||||
func CallPilotAPI(url string) *PilotResponse {
|
|
||||||
data, err := utils.Post("https://webreader.webpilotai.com/api/visit-web", map[string]string{
|
|
||||||
"Content-Type": "application/json",
|
|
||||||
"WebPilot-Friend-UID": GenerateFriendUID(),
|
|
||||||
}, map[string]interface{}{
|
|
||||||
"link": url,
|
|
||||||
"user_has_request": false,
|
|
||||||
})
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return utils.MapToStruct[PilotResponse](data)
|
|
||||||
}
|
|
182
manager/chat.go
182
manager/chat.go
@ -2,7 +2,7 @@ package manager
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"chat/adapter"
|
"chat/adapter"
|
||||||
"chat/adapter/common"
|
adaptercommon "chat/adapter/common"
|
||||||
"chat/addition/web"
|
"chat/addition/web"
|
||||||
"chat/admin"
|
"chat/admin"
|
||||||
"chat/auth"
|
"chat/auth"
|
||||||
@ -10,12 +10,19 @@ import (
|
|||||||
"chat/globals"
|
"chat/globals"
|
||||||
"chat/manager/conversation"
|
"chat/manager/conversation"
|
||||||
"chat/utils"
|
"chat/utils"
|
||||||
|
|
||||||
|
"database/sql"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
"runtime/debug"
|
"runtime/debug"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/go-redis/redis/v8"
|
||||||
)
|
)
|
||||||
|
|
||||||
const defaultMessage = "empty response"
|
const defaultMessage = "empty response"
|
||||||
|
const interruptMessage = "interrupted"
|
||||||
|
|
||||||
func CollectQuota(c *gin.Context, user *auth.User, buffer *utils.Buffer, uncountable bool, err error) {
|
func CollectQuota(c *gin.Context, user *auth.User, buffer *utils.Buffer, uncountable bool, err error) {
|
||||||
db := utils.GetDBFromContext(c)
|
db := utils.GetDBFromContext(c)
|
||||||
@ -34,6 +41,148 @@ func CollectQuota(c *gin.Context, user *auth.User, buffer *utils.Buffer, uncount
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type partialChunk struct {
|
||||||
|
Chunk *globals.Chunk
|
||||||
|
End bool
|
||||||
|
Hit bool
|
||||||
|
Error error
|
||||||
|
}
|
||||||
|
|
||||||
|
func createStopSignal(conn *Connection) chan bool {
|
||||||
|
stopSignal := make(chan bool, 1)
|
||||||
|
go func(conn *Connection, stopSignal chan bool) {
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil && !strings.Contains(fmt.Sprintf("%s", r), "closed channel") {
|
||||||
|
stack := debug.Stack()
|
||||||
|
globals.Warn(fmt.Sprintf("caught panic from stop signal: %s\n%s", r, stack))
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
for {
|
||||||
|
if conn.PeekStop() != nil {
|
||||||
|
stopSignal <- true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}(conn, stopSignal)
|
||||||
|
|
||||||
|
return stopSignal
|
||||||
|
}
|
||||||
|
|
||||||
|
func createChatTask(
|
||||||
|
conn *Connection, user *auth.User, buffer *utils.Buffer, db *sql.DB, cache *redis.Client,
|
||||||
|
model string, instance *conversation.Conversation, segment []globals.Message, plan bool,
|
||||||
|
) (hit bool, err error) {
|
||||||
|
chunkChan := make(chan partialChunk, 24) // the channel to send the chunk data
|
||||||
|
interruptSignal := make(chan error, 1) // the signal to interrupt the chat task routine
|
||||||
|
stopSignal := createStopSignal(conn) // the signal to stop from the client
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
// close all channels
|
||||||
|
close(interruptSignal)
|
||||||
|
close(stopSignal)
|
||||||
|
close(chunkChan)
|
||||||
|
}()
|
||||||
|
|
||||||
|
// create a new chat request routine
|
||||||
|
go func() {
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil && !strings.Contains(fmt.Sprintf("%s", r), "closed channel") {
|
||||||
|
stack := debug.Stack()
|
||||||
|
globals.Warn(fmt.Sprintf("caught panic from chat request: %s\n%s", r, stack))
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
hit, err := channel.NewChatRequestWithCache(
|
||||||
|
cache, buffer,
|
||||||
|
auth.GetGroup(db, user),
|
||||||
|
&adaptercommon.ChatProps{
|
||||||
|
Model: model,
|
||||||
|
Message: segment,
|
||||||
|
MaxTokens: instance.GetMaxTokens(),
|
||||||
|
Temperature: instance.GetTemperature(),
|
||||||
|
TopP: instance.GetTopP(),
|
||||||
|
TopK: instance.GetTopK(),
|
||||||
|
PresencePenalty: instance.GetPresencePenalty(),
|
||||||
|
FrequencyPenalty: instance.GetFrequencyPenalty(),
|
||||||
|
RepetitionPenalty: instance.GetRepetitionPenalty(),
|
||||||
|
},
|
||||||
|
|
||||||
|
// the function to handle the chunk data
|
||||||
|
func(data *globals.Chunk) error {
|
||||||
|
// if interrupt signal is received
|
||||||
|
if len(interruptSignal) > 0 {
|
||||||
|
return errors.New(interruptMessage)
|
||||||
|
}
|
||||||
|
|
||||||
|
// send the chunk data to the channel
|
||||||
|
chunkChan <- partialChunk{
|
||||||
|
Chunk: data,
|
||||||
|
End: false,
|
||||||
|
Hit: true,
|
||||||
|
Error: nil,
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
// chat request routine is done
|
||||||
|
chunkChan <- partialChunk{
|
||||||
|
Chunk: nil,
|
||||||
|
End: true,
|
||||||
|
Hit: hit,
|
||||||
|
Error: err,
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case data := <-chunkChan:
|
||||||
|
if data.Error != nil && data.Error.Error() == interruptMessage {
|
||||||
|
// skip the interrupt message
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
hit = data.Hit
|
||||||
|
err = data.Error
|
||||||
|
|
||||||
|
if data.End {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
sendPackError := conn.SendClient(globals.ChatSegmentResponse{
|
||||||
|
Message: buffer.WriteChunk(data.Chunk),
|
||||||
|
Quota: buffer.GetQuota(),
|
||||||
|
End: false,
|
||||||
|
Plan: plan,
|
||||||
|
})
|
||||||
|
if sendPackError != nil {
|
||||||
|
globals.Warn(fmt.Sprintf("failed to send message to client: %s", sendPackError.Error()))
|
||||||
|
_ = conn.SendClient(globals.ChatSegmentResponse{
|
||||||
|
Message: sendPackError.Error(),
|
||||||
|
Quota: buffer.GetQuota(),
|
||||||
|
End: true,
|
||||||
|
Plan: plan,
|
||||||
|
})
|
||||||
|
|
||||||
|
interruptSignal <- sendPackError
|
||||||
|
|
||||||
|
return hit, sendPackError
|
||||||
|
}
|
||||||
|
case <-stopSignal:
|
||||||
|
globals.Info(fmt.Sprintf("client stopped the chat request (model: %s, client: %s)", model, conn.GetCtx().ClientIP()))
|
||||||
|
_ = conn.SendClient(globals.ChatSegmentResponse{
|
||||||
|
Quota: buffer.GetQuota(),
|
||||||
|
End: true,
|
||||||
|
Plan: plan,
|
||||||
|
})
|
||||||
|
interruptSignal <- errors.New("signal")
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func ChatHandler(conn *Connection, user *auth.User, instance *conversation.Conversation, restart bool) string {
|
func ChatHandler(conn *Connection, user *auth.User, instance *conversation.Conversation, restart bool) string {
|
||||||
defer func() {
|
defer func() {
|
||||||
if err := recover(); err != nil {
|
if err := recover(); err != nil {
|
||||||
@ -66,34 +215,7 @@ func ChatHandler(conn *Connection, user *auth.User, instance *conversation.Conve
|
|||||||
}
|
}
|
||||||
|
|
||||||
buffer := utils.NewBuffer(model, segment, channel.ChargeInstance.GetCharge(model))
|
buffer := utils.NewBuffer(model, segment, channel.ChargeInstance.GetCharge(model))
|
||||||
hit, err := channel.NewChatRequestWithCache(
|
hit, err := createChatTask(conn, user, buffer, db, cache, model, instance, segment, plan)
|
||||||
cache, buffer,
|
|
||||||
auth.GetGroup(db, user),
|
|
||||||
&adaptercommon.ChatProps{
|
|
||||||
Model: model,
|
|
||||||
Message: segment,
|
|
||||||
Buffer: *buffer,
|
|
||||||
MaxTokens: instance.GetMaxTokens(),
|
|
||||||
Temperature: instance.GetTemperature(),
|
|
||||||
TopP: instance.GetTopP(),
|
|
||||||
TopK: instance.GetTopK(),
|
|
||||||
PresencePenalty: instance.GetPresencePenalty(),
|
|
||||||
FrequencyPenalty: instance.GetFrequencyPenalty(),
|
|
||||||
RepetitionPenalty: instance.GetRepetitionPenalty(),
|
|
||||||
},
|
|
||||||
func(data *globals.Chunk) error {
|
|
||||||
if signal := conn.PeekStop(); signal != nil {
|
|
||||||
// stop signal from client
|
|
||||||
return fmt.Errorf("signal")
|
|
||||||
}
|
|
||||||
return conn.SendClient(globals.ChatSegmentResponse{
|
|
||||||
Message: buffer.WriteChunk(data),
|
|
||||||
Quota: buffer.GetQuota(),
|
|
||||||
End: false,
|
|
||||||
Plan: plan,
|
|
||||||
})
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
admin.AnalysisRequest(model, buffer, err)
|
admin.AnalysisRequest(model, buffer, err)
|
||||||
if adapter.IsAvailableError(err) {
|
if adapter.IsAvailableError(err) {
|
||||||
|
Loading…
Reference in New Issue
Block a user