From dff51128d9d393f416ab07201b161aeeefe3d190 Mon Sep 17 00:00:00 2001 From: Zhang Minghan Date: Mon, 14 Aug 2023 14:57:18 +0800 Subject: [PATCH] fix multi-user unique fields --- connection/database.go | 5 +++-- conversation/storage.go | 16 ++++++++++++++-- 2 files changed, 17 insertions(+), 4 deletions(-) diff --git a/connection/database.go b/connection/database.go index 2306130..7d4edaa 100644 --- a/connection/database.go +++ b/connection/database.go @@ -98,10 +98,11 @@ func CreateConversationTable(db *sql.DB) { CREATE TABLE IF NOT EXISTS conversation ( id INT PRIMARY KEY AUTO_INCREMENT, user_id INT, - conversation_id INT UNIQUE, + conversation_id INT, conversation_name VARCHAR(255), data TEXT, - updated_at DATETIME DEFAULT CURRENT_TIMESTAMP + updated_at DATETIME DEFAULT CURRENT_TIMESTAMP, + UNIQUE KEY (user_id, conversation_id) ); `) if err != nil { diff --git a/conversation/storage.go b/conversation/storage.go index 6dd20fd..d88add1 100644 --- a/conversation/storage.go +++ b/conversation/storage.go @@ -4,18 +4,30 @@ import ( "chat/types" "chat/utils" "database/sql" + "log" ) func (c *Conversation) SaveConversation(db *sql.DB) bool { data := utils.ToJson(c.GetMessage()) - _, err := db.Exec("INSERT INTO conversation (user_id, conversation_id, conversation_name, data) VALUES (?, ?, ?, ?) ON DUPLICATE KEY UPDATE conversation_name = ?, data = ?", c.UserID, c.Id, c.Name, data, c.Name, data) + query := "INSERT INTO conversation (user_id, conversation_id, conversation_name, data) VALUES (?, ?, ?, ?) ON DUPLICATE KEY UPDATE conversation_name = VALUES(conversation_name), data = VALUES(data)" + stmt, err := db.Prepare(query) + if err != nil { + return false + } + defer func(stmt *sql.Stmt) { + err := stmt.Close() + if err != nil { + log.Println(err) + } + }(stmt) + + _, err = stmt.Exec(c.UserID, c.Id, c.Name, data) if err != nil { return false } return true } - func GetConversationLengthByUserID(db *sql.DB, userId int64) int64 { var length int64 err := db.QueryRow("SELECT COUNT(*) FROM conversation WHERE user_id = ?", userId).Scan(&length)