From 76e43b1d3e7e338782f746038f8f3a59012a0eba Mon Sep 17 00:00:00 2001 From: Zhang Minghan Date: Sun, 13 Aug 2023 12:10:40 +0800 Subject: [PATCH] fix message pointer error --- api/chat.go | 2 +- connection/database.go | 3 +-- conversation/conversation.go | 4 ++++ utils/base.go | 2 +- 4 files changed, 7 insertions(+), 4 deletions(-) diff --git a/api/chat.go b/api/chat.go index c1dcc32..dbd65af 100644 --- a/api/chat.go +++ b/api/chat.go @@ -70,7 +70,7 @@ func ChatAPI(c *gin.Context) { return } if instance.HandleMessage(db, message) { - keyword, segment := ChatWithWeb(instance.GetMessageSegment(12), true) + keyword, segment := ChatWithWeb(conversation.CopyMessage(instance.GetMessageSegment(12)), true) SendSegmentMessage(conn, types.ChatGPTSegmentResponse{Keyword: keyword, End: false}) msg := "" diff --git a/connection/database.go b/connection/database.go index 5e362eb..2306130 100644 --- a/connection/database.go +++ b/connection/database.go @@ -101,8 +101,7 @@ func CreateConversationTable(db *sql.DB) { conversation_id INT UNIQUE, conversation_name VARCHAR(255), data TEXT, - updated_at DATETIME DEFAULT CURRENT_TIMESTAMP, - FOREIGN KEY (user_id) REFERENCES auth(id) + updated_at DATETIME DEFAULT CURRENT_TIMESTAMP ); `) if err != nil { diff --git a/conversation/conversation.go b/conversation/conversation.go index c177d69..ccdb66d 100644 --- a/conversation/conversation.go +++ b/conversation/conversation.go @@ -63,6 +63,10 @@ func (c *Conversation) GetMessageSegment(length int) []types.ChatGPTMessage { return c.Message[len(c.Message)-length:] } +func CopyMessage(message []types.ChatGPTMessage) []types.ChatGPTMessage { + return utils.UnmarshalJson[[]types.ChatGPTMessage](utils.ToJson(message)) // deep copy +} + func (c *Conversation) GetLastMessage() types.ChatGPTMessage { return c.Message[len(c.Message)-1] } diff --git a/utils/base.go b/utils/base.go index 6fab06e..4133d79 100644 --- a/utils/base.go +++ b/utils/base.go @@ -74,7 +74,7 @@ func GetSegment[T any](arr []T, length int) []T { func GetSegmentString(arr string, length int) string { if length > len(arr) { - return "" + return arr } return arr[:length] }