diff --git a/api/chat.go b/api/chat.go index 1832302..3e70484 100644 --- a/api/chat.go +++ b/api/chat.go @@ -4,7 +4,9 @@ import ( "chat/auth" "chat/conversation" "chat/middleware" + "chat/types" "chat/utils" + "database/sql" "github.com/gin-gonic/gin" "github.com/gorilla/websocket" "net/http" @@ -14,6 +16,10 @@ type WebsocketAuthForm struct { Token string `json:"token" binding:"required"` } +func SendSegmentMessage(conn *websocket.Conn, message types.ChatGPTSegmentResponse) { + _ = conn.WriteMessage(websocket.TextMessage, []byte(utils.ToJson(message))) +} + func ChatAPI(c *gin.Context) { // websocket connection upgrader := websocket.Upgrader{ @@ -55,34 +61,25 @@ func ChatAPI(c *gin.Context) { return } - instance := conversation.NewConversation(user.Username, user.ID) + db := c.MustGet("db").(*sql.DB) + instance := conversation.NewConversation(db, user.ID) for { _, message, err = conn.ReadMessage() if err != nil { return } - if _, err := instance.AddMessageFromUserForm(message); err == nil { + if instance.HandleMessage(db, message) { keyword, segment := ChatWithWeb(instance.GetMessageSegment(12), true) - _ = conn.WriteMessage(websocket.TextMessage, []byte(utils.ToJson(map[string]interface{}{ - "keyword": keyword, - "message": "", - "end": false, - }))) + SendSegmentMessage(conn, types.ChatGPTSegmentResponse{Keyword: keyword, End: false}) - StreamRequest("gpt-3.5-turbo-16k", segment, 2000, func(resp string) { - data := utils.ToJson(map[string]interface{}{ - "keyword": keyword, - "message": resp, - "end": false, + StreamRequest("gpt-3.5-turbo-16k-0613", segment, 2000, func(resp string) { + SendSegmentMessage(conn, types.ChatGPTSegmentResponse{ + Message: resp, + End: false, }) - _ = conn.WriteMessage(websocket.TextMessage, []byte(data)) }) - data := utils.ToJson(map[string]interface{}{ - "message": "", - "end": true, - }) - _ = conn.WriteMessage(websocket.TextMessage, []byte(data)) + SendSegmentMessage(conn, types.ChatGPTSegmentResponse{End: true}) } } } diff --git a/app/src/App.vue b/app/src/App.vue index 2eaa509..bf2c77c 100644 --- a/app/src/App.vue +++ b/app/src/App.vue @@ -6,6 +6,7 @@ import Star from "./components/icons/star.vue"; import { mobile, gpt4 } from "./assets/script/shared"; import Post from "./components/icons/post.vue"; import Github from "./components/icons/github.vue"; +import Heart from "./components/icons/heart.vue"; function goto() { @@ -40,6 +41,10 @@ function toggle(n: boolean) { GPT-4 +
@@ -155,6 +160,35 @@ aside { width: max-content; } +.donate-container { + display: flex; + flex-direction: row; + align-items: center; + margin: 6px auto; + padding: 6px 8px; + vertical-align: center; + border-radius: 8px; + gap: 8px; + width: 235px; + background: rgba(220, 119, 127, 0.25); + transition: .25s; + cursor: pointer; + color: rgb(255, 110, 122); + font-size: 16px; + justify-content: center; + user-select: none; +} + +.donate-container:hover { + background: rgba(255, 110, 122, .3); +} + +.donate-container svg { + width: 32px; + height: 32px; + stroke: rgb(255, 110, 122); +} + .model { display: flex; align-items: center; @@ -280,6 +314,10 @@ aside { flex-direction: column; } + .donate-container { + display: none; + } + .logo { display: none; } diff --git a/app/src/components/icons/heart.vue b/app/src/components/icons/heart.vue new file mode 100644 index 0000000..57a99fc --- /dev/null +++ b/app/src/components/icons/heart.vue @@ -0,0 +1,3 @@ + diff --git a/app/src/views/HomeView.vue b/app/src/views/HomeView.vue index 340a872..7afebfd 100644 --- a/app/src/views/HomeView.vue +++ b/app/src/views/HomeView.vue @@ -160,7 +160,7 @@ onMounted(() => { .input input { width: 100%; height: 32px; - margin: 4px 16px; + margin: 4px 16px 8px; color: var(--card-text); background: var(--card-input); border: 1px solid var(--card-input-border); diff --git a/connection/database.go b/connection/database.go index 668b89b..2a2deb7 100644 --- a/connection/database.go +++ b/connection/database.go @@ -8,11 +8,11 @@ import ( "log" ) -var Database *sql.DB +var _ *sql.DB func ConnectMySQL() *sql.DB { // connect to MySQL - Database, err := sql.Open("mysql", fmt.Sprintf( + db, err := sql.Open("mysql", fmt.Sprintf( "%s:%s@tcp(%s:%d)/%s", viper.GetString("mysql.user"), viper.GetString("mysql.password"), @@ -26,11 +26,12 @@ func ConnectMySQL() *sql.DB { log.Println("Connected to MySQL server successfully") } - CreateUserTable(Database) - CreateSubscriptionTable(Database) - CreatePackageTable(Database) - CreatePaymentLogTable(Database) - return Database + CreateUserTable(db) + CreateConversationTable(db) + CreateSubscriptionTable(db) + CreatePackageTable(db) + CreatePaymentLogTable(db) + return db } func CreateUserTable(db *sql.DB) { @@ -91,3 +92,20 @@ func CreatePackageTable(db *sql.DB) { log.Fatal(err) } } + +func CreateConversationTable(db *sql.DB) { + _, err := db.Exec(` + CREATE TABLE IF NOT EXISTS conversation ( + id INT PRIMARY KEY AUTO_INCREMENT, + user_id INT, + conversation_id INT, + conversation_name VARCHAR(255), + data TEXT, + updated_at DATETIME DEFAULT CURRENT_TIMESTAMP, + FOREIGN KEY (user_id) REFERENCES auth(id) + ); + `) + if err != nil { + log.Fatal(err) + } +} diff --git a/conversation/conversation.go b/conversation/conversation.go index 2acf110..95459aa 100644 --- a/conversation/conversation.go +++ b/conversation/conversation.go @@ -3,35 +3,51 @@ package conversation import ( "chat/types" "chat/utils" + "database/sql" "errors" ) type Conversation struct { - Username string `json:"username"` - Id int64 `json:"id"` - Message []types.ChatGPTMessage `json:"message"` + UserID int64 `json:"user_id"` + Id int64 `json:"id"` + Name string `json:"name"` + Message []types.ChatGPTMessage `json:"message"` } type FormMessage struct { Message string `json:"message" binding:"required"` } -func NewConversation(username string, id int64) *Conversation { +func NewConversation(db *sql.DB, id int64) *Conversation { return &Conversation{ - Username: username, - Id: id, - Message: []types.ChatGPTMessage{}, + UserID: id, + Id: GetConversationLengthByUserID(db, id), + Name: "new chat", + Message: []types.ChatGPTMessage{}, } } -func (c *Conversation) GetUsername() string { - return c.Username +func (c *Conversation) GetName() string { + return c.Name +} + +func (c *Conversation) SetName(db *sql.DB, name string) { + c.Name = name + c.SaveConversation(db) } func (c *Conversation) GetId() int64 { return c.Id } +func (c *Conversation) GetUserID() int64 { + return c.UserID +} + +func (c *Conversation) SetId(id int64) { + c.Id = id +} + func (c *Conversation) GetMessage() []types.ChatGPTMessage { return c.Message } @@ -56,21 +72,21 @@ func (c *Conversation) AddMessage(message types.ChatGPTMessage) { } func (c *Conversation) AddMessageFromUser(message string) { - c.Message = append(c.Message, types.ChatGPTMessage{ + c.AddMessage(types.ChatGPTMessage{ Role: "user", Content: message, }) } func (c *Conversation) AddMessageFromAssistant(message string) { - c.Message = append(c.Message, types.ChatGPTMessage{ + c.AddMessage(types.ChatGPTMessage{ Role: "assistant", Content: message, }) } func (c *Conversation) AddMessageFromSystem(message string) { - c.Message = append(c.Message, types.ChatGPTMessage{ + c.AddMessage(types.ChatGPTMessage{ Role: "system", Content: message, }) @@ -95,9 +111,16 @@ func (c *Conversation) AddMessageFromUserForm(data []byte) (string, error) { return "", errors.New("message is empty") } - c.Message = append(c.Message, types.ChatGPTMessage{ - Role: "user", - Content: form.Message, - }) + c.AddMessageFromUser(form.Message) return form.Message, nil } + +func (c *Conversation) HandleMessage(db *sql.DB, data []byte) bool { + _, err := c.AddMessageFromUserForm(data) + if err != nil { + return false + } + + c.SaveConversation(db) + return true +} diff --git a/conversation/storage.go b/conversation/storage.go new file mode 100644 index 0000000..f1ce080 --- /dev/null +++ b/conversation/storage.go @@ -0,0 +1,75 @@ +package conversation + +import ( + "chat/types" + "chat/utils" + "database/sql" +) + +func (c *Conversation) SaveConversation(db *sql.DB) bool { + data := utils.ToJson(c.GetMessage()) + _, err := db.Exec(` + INSERT INTO conversation ( + user_id, + conversation_id, + conversation_name, + data + ) VALUES (?, ?, ?, ?) + ON DUPLICATE KEY UPDATE + conversation_name = VALUES(conversation_name), + data = VALUES(data) + `, c.GetUserID(), c.GetId(), c.GetName(), data) + if err != nil { + return false + } + return true +} + +func GetConversationLengthByUserID(db *sql.DB, userId int64) int64 { + var length int64 + err := db.QueryRow("SELECT COUNT(*) FROM conversation WHERE user_id = ?", userId).Scan(&length) + if err != nil { + return -1 + } + return length +} + +func LoadConversation(db *sql.DB, userId int64, conversationId int64) *Conversation { + conversation := Conversation{ + UserID: userId, + Id: conversationId, + } + + var data string + err := db.QueryRow("SELECT conversation_name, data FROM conversation WHERE user_id = ? AND conversation_id = ?", userId, conversationId).Scan(&conversation.Name, &data) + if err != nil { + return nil + } + + conversation.Message, err = utils.Unmarshal[[]types.ChatGPTMessage]([]byte(data)) + if err != nil { + return nil + } + + return &conversation +} + +func LoadConversationList(db *sql.DB, userId int64) []Conversation { + var conversationList []Conversation + rows, err := db.Query("SELECT conversation_id, conversation_name FROM conversation WHERE user_id = ?", userId) + if err != nil { + return conversationList + } + defer rows.Close() + + for rows.Next() { + var conversation Conversation + err := rows.Scan(&conversation.Id, &conversation.Name) + if err != nil { + continue + } + conversationList = append(conversationList, conversation) + } + + return conversationList +} diff --git a/types/types.go b/types/types.go index f502380..0a53e98 100644 --- a/types/types.go +++ b/types/types.go @@ -26,3 +26,9 @@ type ChatGPTStreamResponse struct { } `json:"choices"` } `json:"data"` } + +type ChatGPTSegmentResponse struct { + Keyword string `json:"keyword"` + Message string `json:"message"` + End bool `json:"end"` +}