diff --git a/api/chat.go b/api/chat.go
index 1832302..3e70484 100644
--- a/api/chat.go
+++ b/api/chat.go
@@ -4,7 +4,9 @@ import (
"chat/auth"
"chat/conversation"
"chat/middleware"
+ "chat/types"
"chat/utils"
+ "database/sql"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"net/http"
@@ -14,6 +16,10 @@ type WebsocketAuthForm struct {
Token string `json:"token" binding:"required"`
}
+func SendSegmentMessage(conn *websocket.Conn, message types.ChatGPTSegmentResponse) {
+ _ = conn.WriteMessage(websocket.TextMessage, []byte(utils.ToJson(message)))
+}
+
func ChatAPI(c *gin.Context) {
// websocket connection
upgrader := websocket.Upgrader{
@@ -55,34 +61,25 @@ func ChatAPI(c *gin.Context) {
return
}
- instance := conversation.NewConversation(user.Username, user.ID)
+ db := c.MustGet("db").(*sql.DB)
+ instance := conversation.NewConversation(db, user.ID)
for {
_, message, err = conn.ReadMessage()
if err != nil {
return
}
- if _, err := instance.AddMessageFromUserForm(message); err == nil {
+ if instance.HandleMessage(db, message) {
keyword, segment := ChatWithWeb(instance.GetMessageSegment(12), true)
- _ = conn.WriteMessage(websocket.TextMessage, []byte(utils.ToJson(map[string]interface{}{
- "keyword": keyword,
- "message": "",
- "end": false,
- })))
+ SendSegmentMessage(conn, types.ChatGPTSegmentResponse{Keyword: keyword, End: false})
- StreamRequest("gpt-3.5-turbo-16k", segment, 2000, func(resp string) {
- data := utils.ToJson(map[string]interface{}{
- "keyword": keyword,
- "message": resp,
- "end": false,
+ StreamRequest("gpt-3.5-turbo-16k-0613", segment, 2000, func(resp string) {
+ SendSegmentMessage(conn, types.ChatGPTSegmentResponse{
+ Message: resp,
+ End: false,
})
- _ = conn.WriteMessage(websocket.TextMessage, []byte(data))
})
- data := utils.ToJson(map[string]interface{}{
- "message": "",
- "end": true,
- })
- _ = conn.WriteMessage(websocket.TextMessage, []byte(data))
+ SendSegmentMessage(conn, types.ChatGPTSegmentResponse{End: true})
}
}
}
diff --git a/app/src/App.vue b/app/src/App.vue
index 2eaa509..bf2c77c 100644
--- a/app/src/App.vue
+++ b/app/src/App.vue
@@ -6,6 +6,7 @@ import Star from "./components/icons/star.vue";
import { mobile, gpt4 } from "./assets/script/shared";
import Post from "./components/icons/post.vue";
import Github from "./components/icons/github.vue";
+import Heart from "./components/icons/heart.vue";
function goto() {
@@ -40,6 +41,10 @@ function toggle(n: boolean) {
GPT-4
+
+
+ 捐助我们
+
![]()
@@ -155,6 +160,35 @@ aside {
width: max-content;
}
+.donate-container {
+ display: flex;
+ flex-direction: row;
+ align-items: center;
+ margin: 6px auto;
+ padding: 6px 8px;
+ vertical-align: center;
+ border-radius: 8px;
+ gap: 8px;
+ width: 235px;
+ background: rgba(220, 119, 127, 0.25);
+ transition: .25s;
+ cursor: pointer;
+ color: rgb(255, 110, 122);
+ font-size: 16px;
+ justify-content: center;
+ user-select: none;
+}
+
+.donate-container:hover {
+ background: rgba(255, 110, 122, .3);
+}
+
+.donate-container svg {
+ width: 32px;
+ height: 32px;
+ stroke: rgb(255, 110, 122);
+}
+
.model {
display: flex;
align-items: center;
@@ -280,6 +314,10 @@ aside {
flex-direction: column;
}
+ .donate-container {
+ display: none;
+ }
+
.logo {
display: none;
}
diff --git a/app/src/components/icons/heart.vue b/app/src/components/icons/heart.vue
new file mode 100644
index 0000000..57a99fc
--- /dev/null
+++ b/app/src/components/icons/heart.vue
@@ -0,0 +1,3 @@
+
+
+
diff --git a/app/src/views/HomeView.vue b/app/src/views/HomeView.vue
index 340a872..7afebfd 100644
--- a/app/src/views/HomeView.vue
+++ b/app/src/views/HomeView.vue
@@ -160,7 +160,7 @@ onMounted(() => {
.input input {
width: 100%;
height: 32px;
- margin: 4px 16px;
+ margin: 4px 16px 8px;
color: var(--card-text);
background: var(--card-input);
border: 1px solid var(--card-input-border);
diff --git a/connection/database.go b/connection/database.go
index 668b89b..2a2deb7 100644
--- a/connection/database.go
+++ b/connection/database.go
@@ -8,11 +8,11 @@ import (
"log"
)
-var Database *sql.DB
+var _ *sql.DB
func ConnectMySQL() *sql.DB {
// connect to MySQL
- Database, err := sql.Open("mysql", fmt.Sprintf(
+ db, err := sql.Open("mysql", fmt.Sprintf(
"%s:%s@tcp(%s:%d)/%s",
viper.GetString("mysql.user"),
viper.GetString("mysql.password"),
@@ -26,11 +26,12 @@ func ConnectMySQL() *sql.DB {
log.Println("Connected to MySQL server successfully")
}
- CreateUserTable(Database)
- CreateSubscriptionTable(Database)
- CreatePackageTable(Database)
- CreatePaymentLogTable(Database)
- return Database
+ CreateUserTable(db)
+ CreateConversationTable(db)
+ CreateSubscriptionTable(db)
+ CreatePackageTable(db)
+ CreatePaymentLogTable(db)
+ return db
}
func CreateUserTable(db *sql.DB) {
@@ -91,3 +92,20 @@ func CreatePackageTable(db *sql.DB) {
log.Fatal(err)
}
}
+
+func CreateConversationTable(db *sql.DB) {
+ _, err := db.Exec(`
+ CREATE TABLE IF NOT EXISTS conversation (
+ id INT PRIMARY KEY AUTO_INCREMENT,
+ user_id INT,
+ conversation_id INT,
+ conversation_name VARCHAR(255),
+ data TEXT,
+ updated_at DATETIME DEFAULT CURRENT_TIMESTAMP,
+ FOREIGN KEY (user_id) REFERENCES auth(id)
+ );
+ `)
+ if err != nil {
+ log.Fatal(err)
+ }
+}
diff --git a/conversation/conversation.go b/conversation/conversation.go
index 2acf110..95459aa 100644
--- a/conversation/conversation.go
+++ b/conversation/conversation.go
@@ -3,35 +3,51 @@ package conversation
import (
"chat/types"
"chat/utils"
+ "database/sql"
"errors"
)
type Conversation struct {
- Username string `json:"username"`
- Id int64 `json:"id"`
- Message []types.ChatGPTMessage `json:"message"`
+ UserID int64 `json:"user_id"`
+ Id int64 `json:"id"`
+ Name string `json:"name"`
+ Message []types.ChatGPTMessage `json:"message"`
}
type FormMessage struct {
Message string `json:"message" binding:"required"`
}
-func NewConversation(username string, id int64) *Conversation {
+func NewConversation(db *sql.DB, id int64) *Conversation {
return &Conversation{
- Username: username,
- Id: id,
- Message: []types.ChatGPTMessage{},
+ UserID: id,
+ Id: GetConversationLengthByUserID(db, id),
+ Name: "new chat",
+ Message: []types.ChatGPTMessage{},
}
}
-func (c *Conversation) GetUsername() string {
- return c.Username
+func (c *Conversation) GetName() string {
+ return c.Name
+}
+
+func (c *Conversation) SetName(db *sql.DB, name string) {
+ c.Name = name
+ c.SaveConversation(db)
}
func (c *Conversation) GetId() int64 {
return c.Id
}
+func (c *Conversation) GetUserID() int64 {
+ return c.UserID
+}
+
+func (c *Conversation) SetId(id int64) {
+ c.Id = id
+}
+
func (c *Conversation) GetMessage() []types.ChatGPTMessage {
return c.Message
}
@@ -56,21 +72,21 @@ func (c *Conversation) AddMessage(message types.ChatGPTMessage) {
}
func (c *Conversation) AddMessageFromUser(message string) {
- c.Message = append(c.Message, types.ChatGPTMessage{
+ c.AddMessage(types.ChatGPTMessage{
Role: "user",
Content: message,
})
}
func (c *Conversation) AddMessageFromAssistant(message string) {
- c.Message = append(c.Message, types.ChatGPTMessage{
+ c.AddMessage(types.ChatGPTMessage{
Role: "assistant",
Content: message,
})
}
func (c *Conversation) AddMessageFromSystem(message string) {
- c.Message = append(c.Message, types.ChatGPTMessage{
+ c.AddMessage(types.ChatGPTMessage{
Role: "system",
Content: message,
})
@@ -95,9 +111,16 @@ func (c *Conversation) AddMessageFromUserForm(data []byte) (string, error) {
return "", errors.New("message is empty")
}
- c.Message = append(c.Message, types.ChatGPTMessage{
- Role: "user",
- Content: form.Message,
- })
+ c.AddMessageFromUser(form.Message)
return form.Message, nil
}
+
+func (c *Conversation) HandleMessage(db *sql.DB, data []byte) bool {
+ _, err := c.AddMessageFromUserForm(data)
+ if err != nil {
+ return false
+ }
+
+ c.SaveConversation(db)
+ return true
+}
diff --git a/conversation/storage.go b/conversation/storage.go
new file mode 100644
index 0000000..f1ce080
--- /dev/null
+++ b/conversation/storage.go
@@ -0,0 +1,75 @@
+package conversation
+
+import (
+ "chat/types"
+ "chat/utils"
+ "database/sql"
+)
+
+func (c *Conversation) SaveConversation(db *sql.DB) bool {
+ data := utils.ToJson(c.GetMessage())
+ _, err := db.Exec(`
+ INSERT INTO conversation (
+ user_id,
+ conversation_id,
+ conversation_name,
+ data
+ ) VALUES (?, ?, ?, ?)
+ ON DUPLICATE KEY UPDATE
+ conversation_name = VALUES(conversation_name),
+ data = VALUES(data)
+ `, c.GetUserID(), c.GetId(), c.GetName(), data)
+ if err != nil {
+ return false
+ }
+ return true
+}
+
+func GetConversationLengthByUserID(db *sql.DB, userId int64) int64 {
+ var length int64
+ err := db.QueryRow("SELECT COUNT(*) FROM conversation WHERE user_id = ?", userId).Scan(&length)
+ if err != nil {
+ return -1
+ }
+ return length
+}
+
+func LoadConversation(db *sql.DB, userId int64, conversationId int64) *Conversation {
+ conversation := Conversation{
+ UserID: userId,
+ Id: conversationId,
+ }
+
+ var data string
+ err := db.QueryRow("SELECT conversation_name, data FROM conversation WHERE user_id = ? AND conversation_id = ?", userId, conversationId).Scan(&conversation.Name, &data)
+ if err != nil {
+ return nil
+ }
+
+ conversation.Message, err = utils.Unmarshal[[]types.ChatGPTMessage]([]byte(data))
+ if err != nil {
+ return nil
+ }
+
+ return &conversation
+}
+
+func LoadConversationList(db *sql.DB, userId int64) []Conversation {
+ var conversationList []Conversation
+ rows, err := db.Query("SELECT conversation_id, conversation_name FROM conversation WHERE user_id = ?", userId)
+ if err != nil {
+ return conversationList
+ }
+ defer rows.Close()
+
+ for rows.Next() {
+ var conversation Conversation
+ err := rows.Scan(&conversation.Id, &conversation.Name)
+ if err != nil {
+ continue
+ }
+ conversationList = append(conversationList, conversation)
+ }
+
+ return conversationList
+}
diff --git a/types/types.go b/types/types.go
index f502380..0a53e98 100644
--- a/types/types.go
+++ b/types/types.go
@@ -26,3 +26,9 @@ type ChatGPTStreamResponse struct {
} `json:"choices"`
} `json:"data"`
}
+
+type ChatGPTSegmentResponse struct {
+ Keyword string `json:"keyword"`
+ Message string `json:"message"`
+ End bool `json:"end"`
+}