From 9f5a4a2ee6d296d7ff7417f6fa301feb7278df5f Mon Sep 17 00:00:00 2001 From: Zhang Minghan Date: Sun, 23 Jul 2023 11:05:46 +0800 Subject: [PATCH] Implemented feature: chatgpt conversation, conversation segment, model selection, websocket cross site validation --- api/anonymous.go | 5 +- api/chat.go | 45 ++++++++++---- api/stream.go | 7 ++- app/src/assets/script/conf.ts | 2 +- app/src/assets/script/conversation.ts | 11 ++-- auth/call.go | 5 +- conversation/conversation.go | 88 +++++++++++++++++++++++++++ middleware/cors.go | 4 +- {api => types}/types.go | 2 +- utils/char.go | 6 ++ 10 files changed, 147 insertions(+), 28 deletions(-) create mode 100644 conversation/conversation.go rename {api => types}/types.go (97%) diff --git a/api/anonymous.go b/api/anonymous.go index f012aeb..3c7fc4a 100644 --- a/api/anonymous.go +++ b/api/anonymous.go @@ -1,6 +1,7 @@ package api import ( + "chat/types" "chat/utils" "fmt" "github.com/gin-gonic/gin" @@ -19,9 +20,9 @@ func GetAnonymousResponse(message string) (string, error) { res, err := utils.Post(viper.GetString("openai.anonymous_endpoint")+"/chat/completions", map[string]string{ "Content-Type": "application/json", "Authorization": "Bearer " + viper.GetString("openai.anonymous"), - }, ChatGPTRequest{ + }, types.ChatGPTRequest{ Model: "gpt-3.5-turbo-16k", - Messages: []ChatGPTMessage{ + Messages: []types.ChatGPTMessage{ { Role: "user", Content: message, diff --git a/api/chat.go b/api/chat.go index dff0e9f..339e47c 100644 --- a/api/chat.go +++ b/api/chat.go @@ -1,17 +1,29 @@ package api import ( + "chat/auth" + "chat/conversation" + "chat/middleware" + "chat/utils" "encoding/json" "github.com/gin-gonic/gin" "github.com/gorilla/websocket" "net/http" ) +type WebsocketAuthForm struct { + Token string `json:"token" binding:"required"` +} + func ChatAPI(c *gin.Context) { // websocket connection upgrader := websocket.Upgrader{ CheckOrigin: func(r *http.Request) bool { - return true + origin := c.Request.Header.Get("Origin") + if utils.Contains(origin, middleware.AllowedOrigins) { + return true + } + return false }, } conn, err := upgrader.Upgrade(c.Writer, c.Request, nil) @@ -29,21 +41,30 @@ func ChatAPI(c *gin.Context) { return } }(conn) + + _, message, err := conn.ReadMessage() + if err != nil { + return + } + form, err := utils.Unmarshal[WebsocketAuthForm](message) + if err != nil { + return + } + + user := auth.ParseToken(c, form.Token) + if user == nil { + return + } + + instance := conversation.NewConversation(user.Username, user.ID) + for { - _, message, err := conn.ReadMessage() + _, message, err = conn.ReadMessage() if err != nil { return } - - var form map[string]interface{} - if err := json.Unmarshal(message, &form); err == nil { - message := form["message"].(string) - StreamRequest("gpt-4", []ChatGPTMessage{ - { - Role: "user", - Content: message, - }, - }, 500, func(resp string) { + if _, err := instance.AddMessageFromUserForm(message); err == nil { + StreamRequest("gpt-3.5-turbo", instance.GetMessageSegment(5), 500, func(resp string) { data, _ := json.Marshal(map[string]interface{}{ "message": resp, "end": false, diff --git a/api/stream.go b/api/stream.go index 0cac44b..81b4129 100644 --- a/api/stream.go +++ b/api/stream.go @@ -1,6 +1,7 @@ package api import ( + "chat/types" "chat/utils" "crypto/tls" "encoding/json" @@ -26,7 +27,7 @@ func processLine(buf []byte) []string { if item == "{data: [DONE]}" { break } - var form ChatGPTStreamResponse + var form types.ChatGPTStreamResponse if err := json.Unmarshal([]byte(item), &form); err != nil { log.Fatal(err) } @@ -38,11 +39,11 @@ func processLine(buf []byte) []string { return resp } -func StreamRequest(model string, messages []ChatGPTMessage, token int, callback func(string)) { +func StreamRequest(model string, messages []types.ChatGPTMessage, token int, callback func(string)) { http.DefaultTransport.(*http.Transport).TLSClientConfig = &tls.Config{InsecureSkipVerify: true} client := &http.Client{} - req, err := http.NewRequest("POST", viper.GetString("openai.user_endpoint")+"/chat/completions", utils.ConvertBody(ChatGPTRequest{ + req, err := http.NewRequest("POST", viper.GetString("openai.user_endpoint")+"/chat/completions", utils.ConvertBody(types.ChatGPTRequest{ Model: model, Messages: messages, MaxToken: token, diff --git a/app/src/assets/script/conf.ts b/app/src/assets/script/conf.ts index 3ef63cb..c239470 100644 --- a/app/src/assets/script/conf.ts +++ b/app/src/assets/script/conf.ts @@ -1,6 +1,6 @@ import axios from "axios"; -export const deploy: boolean = false; +export const deploy: boolean = true; export let rest_api: string = "http://localhost:8094"; export let ws_api: string = "ws://localhost:8094"; diff --git a/app/src/assets/script/conversation.ts b/app/src/assets/script/conversation.ts index 2ee6ee0..1d56465 100644 --- a/app/src/assets/script/conversation.ts +++ b/app/src/assets/script/conversation.ts @@ -1,7 +1,7 @@ import {nextTick, reactive, ref} from "vue"; import type { Ref } from "vue"; import axios from "axios"; -import {auth} from "./auth"; +import {auth, token} from "./auth"; import {ws_api} from "./conf"; type Message = { @@ -31,6 +31,9 @@ export class Connection { this.state = false; this.connection.onopen = () => { this.state = true; + this.send({ + token: token.value, + }) } this.connection.onclose = () => { this.state = false; @@ -91,13 +94,13 @@ export class Conversation { message.value += res.message; end.value = res.end; }) - this.addDynamicMessageFromAI(message, end); const status = this.connection?.send({ message: content, }); - if (!status) { + if (status) { + this.addDynamicMessageFromAI(message, end); + } else { this.addMessageFromAI("网络错误,请稍后再试"); - return; } } diff --git a/auth/call.go b/auth/call.go index dce11cf..ce14304 100644 --- a/auth/call.go +++ b/auth/call.go @@ -26,7 +26,6 @@ func Validate(token string) *ValidateUserResponse { } converter, _ := json.Marshal(res) - var response ValidateUserResponse - _ = json.Unmarshal(converter, &response) - return &response + resp, _ := utils.Unmarshal[ValidateUserResponse](converter) + return &resp } diff --git a/conversation/conversation.go b/conversation/conversation.go new file mode 100644 index 0000000..00c20a1 --- /dev/null +++ b/conversation/conversation.go @@ -0,0 +1,88 @@ +package conversation + +import ( + "chat/types" + "chat/utils" +) + +type Conversation struct { + Username string `json:"username"` + Id int64 `json:"id"` + Message []types.ChatGPTMessage `json:"message"` +} + +type FormMessage struct { + Message string `json:"message" binding:"required"` +} + +func NewConversation(username string, id int64) *Conversation { + return &Conversation{ + Username: username, + Id: id, + Message: []types.ChatGPTMessage{}, + } +} + +func (c *Conversation) GetUsername() string { + return c.Username +} + +func (c *Conversation) GetId() int64 { + return c.Id +} + +func (c *Conversation) GetMessage() []types.ChatGPTMessage { + return c.Message +} + +func (c *Conversation) GetMessageSize() int { + return len(c.Message) +} + +func (c *Conversation) GetMessageSegment(length int) []types.ChatGPTMessage { + if length > len(c.Message) { + return c.Message + } + return c.Message[len(c.Message)-length:] +} + +func (c *Conversation) GetLastMessage() types.ChatGPTMessage { + return c.Message[len(c.Message)-1] +} + +func (c *Conversation) AddMessage(message types.ChatGPTMessage) { + c.Message = append(c.Message, message) +} + +func (c *Conversation) AddMessageFromUser(message string) { + c.Message = append(c.Message, types.ChatGPTMessage{ + Role: "user", + Content: message, + }) +} + +func (c *Conversation) AddMessageFromAssistant(message string) { + c.Message = append(c.Message, types.ChatGPTMessage{ + Role: "assistant", + Content: message, + }) +} + +func (c *Conversation) AddMessageFromSystem(message string) { + c.Message = append(c.Message, types.ChatGPTMessage{ + Role: "system", + Content: message, + }) +} + +func (c *Conversation) AddMessageFromUserForm(data []byte) (string, error) { + form, err := utils.Unmarshal[FormMessage](data) + if err != nil { + return "", err + } + c.Message = append(c.Message, types.ChatGPTMessage{ + Role: "user", + Content: form.Message, + }) + return form.Message, nil +} diff --git a/middleware/cors.go b/middleware/cors.go index 912b443..54bc7a1 100644 --- a/middleware/cors.go +++ b/middleware/cors.go @@ -6,7 +6,7 @@ import ( "net/http" ) -var allowedOrigins = []string{ +var AllowedOrigins = []string{ "https://fystart.cn", "https://www.fystart.cn", "https://nio.fystart.cn", @@ -16,7 +16,7 @@ var allowedOrigins = []string{ func CORSMiddleware() gin.HandlerFunc { return func(c *gin.Context) { origin := c.Request.Header.Get("Origin") - if utils.Contains(origin, allowedOrigins) { + if utils.Contains(origin, AllowedOrigins) { c.Writer.Header().Set("Access-Control-Allow-Origin", origin) c.Writer.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS") c.Writer.Header().Set("Access-Control-Allow-Headers", "Origin, Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization") diff --git a/api/types.go b/types/types.go similarity index 97% rename from api/types.go rename to types/types.go index 1a22c66..f502380 100644 --- a/api/types.go +++ b/types/types.go @@ -1,4 +1,4 @@ -package api +package types type ChatGPTMessage struct { Role string `json:"role"` diff --git a/utils/char.go b/utils/char.go index fae5274..1df77df 100644 --- a/utils/char.go +++ b/utils/char.go @@ -1,6 +1,7 @@ package utils import ( + "encoding/json" "math/rand" "strconv" "time" @@ -30,3 +31,8 @@ func ConvertTime(t []uint8) *time.Time { } return &val } + +func Unmarshal[T interface{}](data []byte) (form T, err error) { + err = json.Unmarshal(data, &form) + return form, err +}