feat: support apply chat params in restart action (#113)

This commit is contained in:
Zhang Minghan 2024-03-13 16:20:14 +08:00
parent 40cb56761d
commit bd9b06e420
6 changed files with 51 additions and 13 deletions

View File

@ -5,8 +5,8 @@ import (
"chat/manager/conversation" "chat/manager/conversation"
) )
func UsingWebSegment(instance *conversation.Conversation) []globals.Message { func UsingWebSegment(instance *conversation.Conversation, restart bool) []globals.Message {
segment := conversation.CopyMessage(instance.GetChatMessage()) segment := conversation.CopyMessage(instance.GetChatMessage(restart))
if instance.IsEnableWeb() { if instance.IsEnableWeb() {
segment = ChatWithWeb(segment) segment = ChatWithWeb(segment)

View File

@ -128,11 +128,12 @@ export class Connection {
}); });
} }
public sendEvent(t: any, event: string, data?: string) { public sendEvent(t: any, event: string, data?: string, props?: ChatProps) {
this.sendWithRetry(t, { this.sendWithRetry(t, {
type: event, type: event,
message: data || "", message: data || "",
model: "event", model: "event",
...props,
}); });
} }
@ -140,8 +141,8 @@ export class Connection {
this.sendEvent(t, "stop"); this.sendEvent(t, "stop");
} }
public sendRestartEvent(t: any) { public sendRestartEvent(t: any, data?: ChatProps) {
this.sendEvent(t, "restart"); this.sendEvent(t, "restart", undefined, data);
} }
public sendMaskEvent(t: any, mask: Mask) { public sendMaskEvent(t: any, mask: Mask) {
@ -239,9 +240,9 @@ export class ConnectionStack {
conn && conn.sendStopEvent(t); conn && conn.sendStopEvent(t);
} }
public sendRestartEvent(id: number, t: any) { public sendRestartEvent(id: number, t: any, data?: ChatProps) {
const conn = this.getConnection(id); const conn = this.getConnection(id);
conn && conn.sendRestartEvent(t); conn && conn.sendRestartEvent(t, data);
} }
public sendMaskEvent(id: number, t: any, mask: Mask) { public sendMaskEvent(id: number, t: any, mask: Mask) {

View File

@ -550,7 +550,20 @@ export function useMessageActions() {
if (!stack.hasConnection(current)) { if (!stack.hasConnection(current)) {
stack.createConnection(current); stack.createConnection(current);
} }
stack.sendRestartEvent(current, t); stack.sendRestartEvent(current, t, {
web,
model,
context: history,
ignore_context: !context,
max_tokens,
temperature,
top_p,
top_k,
presence_penalty,
frequency_penalty,
repetition_penalty,
message: "",
});
// remove the last message if it's from assistant and create a new message // remove the last message if it's from assistant and create a new message
dispatch(restartMessage(current)); dispatch(restartMessage(current));

View File

@ -56,7 +56,7 @@ func MockStreamSender(conn *Connection, message string) {
}) })
} }
func ChatHandler(conn *Connection, user *auth.User, instance *conversation.Conversation) string { func ChatHandler(conn *Connection, user *auth.User, instance *conversation.Conversation, restart bool) string {
defer func() { defer func() {
if err := recover(); err != nil { if err := recover(); err != nil {
stack := debug.Stack() stack := debug.Stack()
@ -70,7 +70,7 @@ func ChatHandler(conn *Connection, user *auth.User, instance *conversation.Conve
cache := conn.GetCache() cache := conn.GetCache()
model := instance.GetModel() model := instance.GetModel()
segment := adapter.ClearMessages(model, web.UsingWebSegment(instance)) segment := adapter.ClearMessages(model, web.UsingWebSegment(instance, restart))
check, plan := auth.CanEnableModelWithSubscription(db, cache, user, model) check, plan := auth.CanEnableModelWithSubscription(db, cache, user, model)
conn.Send(globals.ChatSegmentResponse{ conn.Send(globals.ChatSegmentResponse{

View File

@ -233,7 +233,28 @@ func (c *Conversation) GetMessageSegment(length int) []globals.Message {
return c.Message[len(c.Message)-length:] return c.Message[len(c.Message)-length:]
} }
func (c *Conversation) GetChatMessage() []globals.Message { func (c *Conversation) GetChatMessage(restart bool) []globals.Message {
if restart {
// remove all last `assistant` role message
cp := CopyMessage(c.Message)
var index int
for index = len(cp) - 1; index >= 0; index-- {
if cp[index].Role != globals.Assistant {
break
}
}
if index >= 0 {
cp = cp[:index+1]
}
if c.GetContextLength() > len(cp) {
return cp
}
return cp[len(cp)-c.GetContextLength():]
}
return c.GetMessageSegment(c.GetContextLength()) return c.GetMessageSegment(c.GetContextLength())
} }

View File

@ -83,7 +83,7 @@ func ChatAPI(c *gin.Context) {
switch form.Type { switch form.Type {
case ChatType: case ChatType:
if instance.HandleMessage(db, form) { if instance.HandleMessage(db, form) {
response := ChatHandler(buf, user, instance) response := ChatHandler(buf, user, instance, false)
instance.SaveResponse(db, response) instance.SaveResponse(db, response)
} }
case StopType: case StopType:
@ -91,7 +91,10 @@ func ChatAPI(c *gin.Context) {
case ShareType: case ShareType:
instance.LoadSharing(db, form.Message) instance.LoadSharing(db, form.Message)
case RestartType: case RestartType:
response := ChatHandler(buf, user, instance) // reset the params if set
instance.ApplyParam(form)
response := ChatHandler(buf, user, instance, true)
instance.SaveResponse(db, response) instance.SaveResponse(db, response)
case MaskType: case MaskType:
instance.LoadMask(form.Message) instance.LoadMask(form.Message)