From bd9b06e42022c383ad8cb795079310627981ab0c Mon Sep 17 00:00:00 2001 From: Zhang Minghan Date: Wed, 13 Mar 2024 16:20:14 +0800 Subject: [PATCH] feat: support apply chat params in restart action (#113) --- addition/web/utils.go | 4 ++-- app/src/api/connection.ts | 11 ++++++----- app/src/store/chat.ts | 15 ++++++++++++++- manager/chat.go | 4 ++-- manager/conversation/conversation.go | 23 ++++++++++++++++++++++- manager/manager.go | 7 +++++-- 6 files changed, 51 insertions(+), 13 deletions(-) diff --git a/addition/web/utils.go b/addition/web/utils.go index 70603ed..b56d58c 100644 --- a/addition/web/utils.go +++ b/addition/web/utils.go @@ -5,8 +5,8 @@ import ( "chat/manager/conversation" ) -func UsingWebSegment(instance *conversation.Conversation) []globals.Message { - segment := conversation.CopyMessage(instance.GetChatMessage()) +func UsingWebSegment(instance *conversation.Conversation, restart bool) []globals.Message { + segment := conversation.CopyMessage(instance.GetChatMessage(restart)) if instance.IsEnableWeb() { segment = ChatWithWeb(segment) diff --git a/app/src/api/connection.ts b/app/src/api/connection.ts index cbe9777..f56272a 100644 --- a/app/src/api/connection.ts +++ b/app/src/api/connection.ts @@ -128,11 +128,12 @@ export class Connection { }); } - public sendEvent(t: any, event: string, data?: string) { + public sendEvent(t: any, event: string, data?: string, props?: ChatProps) { this.sendWithRetry(t, { type: event, message: data || "", model: "event", + ...props, }); } @@ -140,8 +141,8 @@ export class Connection { this.sendEvent(t, "stop"); } - public sendRestartEvent(t: any) { - this.sendEvent(t, "restart"); + public sendRestartEvent(t: any, data?: ChatProps) { + this.sendEvent(t, "restart", undefined, data); } public sendMaskEvent(t: any, mask: Mask) { @@ -239,9 +240,9 @@ export class ConnectionStack { conn && conn.sendStopEvent(t); } - public sendRestartEvent(id: number, t: any) { + public sendRestartEvent(id: number, t: any, data?: ChatProps) { const conn = this.getConnection(id); - conn && conn.sendRestartEvent(t); + conn && conn.sendRestartEvent(t, data); } public sendMaskEvent(id: number, t: any, mask: Mask) { diff --git a/app/src/store/chat.ts b/app/src/store/chat.ts index 42aa46e..430de08 100644 --- a/app/src/store/chat.ts +++ b/app/src/store/chat.ts @@ -550,7 +550,20 @@ export function useMessageActions() { if (!stack.hasConnection(current)) { stack.createConnection(current); } - stack.sendRestartEvent(current, t); + stack.sendRestartEvent(current, t, { + web, + model, + context: history, + ignore_context: !context, + max_tokens, + temperature, + top_p, + top_k, + presence_penalty, + frequency_penalty, + repetition_penalty, + message: "", + }); // remove the last message if it's from assistant and create a new message dispatch(restartMessage(current)); diff --git a/manager/chat.go b/manager/chat.go index fd0d85b..36e44f0 100644 --- a/manager/chat.go +++ b/manager/chat.go @@ -56,7 +56,7 @@ func MockStreamSender(conn *Connection, message string) { }) } -func ChatHandler(conn *Connection, user *auth.User, instance *conversation.Conversation) string { +func ChatHandler(conn *Connection, user *auth.User, instance *conversation.Conversation, restart bool) string { defer func() { if err := recover(); err != nil { stack := debug.Stack() @@ -70,7 +70,7 @@ func ChatHandler(conn *Connection, user *auth.User, instance *conversation.Conve cache := conn.GetCache() model := instance.GetModel() - segment := adapter.ClearMessages(model, web.UsingWebSegment(instance)) + segment := adapter.ClearMessages(model, web.UsingWebSegment(instance, restart)) check, plan := auth.CanEnableModelWithSubscription(db, cache, user, model) conn.Send(globals.ChatSegmentResponse{ diff --git a/manager/conversation/conversation.go b/manager/conversation/conversation.go index 7472645..c2251e3 100644 --- a/manager/conversation/conversation.go +++ b/manager/conversation/conversation.go @@ -233,7 +233,28 @@ func (c *Conversation) GetMessageSegment(length int) []globals.Message { return c.Message[len(c.Message)-length:] } -func (c *Conversation) GetChatMessage() []globals.Message { +func (c *Conversation) GetChatMessage(restart bool) []globals.Message { + if restart { + // remove all last `assistant` role message + cp := CopyMessage(c.Message) + + var index int + for index = len(cp) - 1; index >= 0; index-- { + if cp[index].Role != globals.Assistant { + break + } + } + if index >= 0 { + cp = cp[:index+1] + } + + if c.GetContextLength() > len(cp) { + return cp + } + + return cp[len(cp)-c.GetContextLength():] + } + return c.GetMessageSegment(c.GetContextLength()) } diff --git a/manager/manager.go b/manager/manager.go index add6541..86bcac6 100644 --- a/manager/manager.go +++ b/manager/manager.go @@ -83,7 +83,7 @@ func ChatAPI(c *gin.Context) { switch form.Type { case ChatType: if instance.HandleMessage(db, form) { - response := ChatHandler(buf, user, instance) + response := ChatHandler(buf, user, instance, false) instance.SaveResponse(db, response) } case StopType: @@ -91,7 +91,10 @@ func ChatAPI(c *gin.Context) { case ShareType: instance.LoadSharing(db, form.Message) case RestartType: - response := ChatHandler(buf, user, instance) + // reset the params if set + instance.ApplyParam(form) + + response := ChatHandler(buf, user, instance, true) instance.SaveResponse(db, response) case MaskType: instance.LoadMask(form.Message)