Implemented feature: chatgpt conversation, conversation segment, model selection, websocket cross site validation

This commit is contained in:
Zhang Minghan 2023-07-23 11:05:46 +08:00
parent f03daa09c8
commit 9f5a4a2ee6
10 changed files with 147 additions and 28 deletions

View File

@ -1,6 +1,7 @@
package api
import (
"chat/types"
"chat/utils"
"fmt"
"github.com/gin-gonic/gin"
@ -19,9 +20,9 @@ func GetAnonymousResponse(message string) (string, error) {
res, err := utils.Post(viper.GetString("openai.anonymous_endpoint")+"/chat/completions", map[string]string{
"Content-Type": "application/json",
"Authorization": "Bearer " + viper.GetString("openai.anonymous"),
}, ChatGPTRequest{
}, types.ChatGPTRequest{
Model: "gpt-3.5-turbo-16k",
Messages: []ChatGPTMessage{
Messages: []types.ChatGPTMessage{
{
Role: "user",
Content: message,

View File

@ -1,17 +1,29 @@
package api
import (
"chat/auth"
"chat/conversation"
"chat/middleware"
"chat/utils"
"encoding/json"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"net/http"
)
type WebsocketAuthForm struct {
Token string `json:"token" binding:"required"`
}
func ChatAPI(c *gin.Context) {
// websocket connection
upgrader := websocket.Upgrader{
CheckOrigin: func(r *http.Request) bool {
return true
origin := c.Request.Header.Get("Origin")
if utils.Contains(origin, middleware.AllowedOrigins) {
return true
}
return false
},
}
conn, err := upgrader.Upgrade(c.Writer, c.Request, nil)
@ -29,21 +41,30 @@ func ChatAPI(c *gin.Context) {
return
}
}(conn)
_, message, err := conn.ReadMessage()
if err != nil {
return
}
form, err := utils.Unmarshal[WebsocketAuthForm](message)
if err != nil {
return
}
user := auth.ParseToken(c, form.Token)
if user == nil {
return
}
instance := conversation.NewConversation(user.Username, user.ID)
for {
_, message, err := conn.ReadMessage()
_, message, err = conn.ReadMessage()
if err != nil {
return
}
var form map[string]interface{}
if err := json.Unmarshal(message, &form); err == nil {
message := form["message"].(string)
StreamRequest("gpt-4", []ChatGPTMessage{
{
Role: "user",
Content: message,
},
}, 500, func(resp string) {
if _, err := instance.AddMessageFromUserForm(message); err == nil {
StreamRequest("gpt-3.5-turbo", instance.GetMessageSegment(5), 500, func(resp string) {
data, _ := json.Marshal(map[string]interface{}{
"message": resp,
"end": false,

View File

@ -1,6 +1,7 @@
package api
import (
"chat/types"
"chat/utils"
"crypto/tls"
"encoding/json"
@ -26,7 +27,7 @@ func processLine(buf []byte) []string {
if item == "{data: [DONE]}" {
break
}
var form ChatGPTStreamResponse
var form types.ChatGPTStreamResponse
if err := json.Unmarshal([]byte(item), &form); err != nil {
log.Fatal(err)
}
@ -38,11 +39,11 @@ func processLine(buf []byte) []string {
return resp
}
func StreamRequest(model string, messages []ChatGPTMessage, token int, callback func(string)) {
func StreamRequest(model string, messages []types.ChatGPTMessage, token int, callback func(string)) {
http.DefaultTransport.(*http.Transport).TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
client := &http.Client{}
req, err := http.NewRequest("POST", viper.GetString("openai.user_endpoint")+"/chat/completions", utils.ConvertBody(ChatGPTRequest{
req, err := http.NewRequest("POST", viper.GetString("openai.user_endpoint")+"/chat/completions", utils.ConvertBody(types.ChatGPTRequest{
Model: model,
Messages: messages,
MaxToken: token,

View File

@ -1,6 +1,6 @@
import axios from "axios";
export const deploy: boolean = false;
export const deploy: boolean = true;
export let rest_api: string = "http://localhost:8094";
export let ws_api: string = "ws://localhost:8094";

View File

@ -1,7 +1,7 @@
import {nextTick, reactive, ref} from "vue";
import type { Ref } from "vue";
import axios from "axios";
import {auth} from "./auth";
import {auth, token} from "./auth";
import {ws_api} from "./conf";
type Message = {
@ -31,6 +31,9 @@ export class Connection {
this.state = false;
this.connection.onopen = () => {
this.state = true;
this.send({
token: token.value,
})
}
this.connection.onclose = () => {
this.state = false;
@ -91,13 +94,13 @@ export class Conversation {
message.value += res.message;
end.value = res.end;
})
this.addDynamicMessageFromAI(message, end);
const status = this.connection?.send({
message: content,
});
if (!status) {
if (status) {
this.addDynamicMessageFromAI(message, end);
} else {
this.addMessageFromAI("网络错误,请稍后再试");
return;
}
}

View File

@ -26,7 +26,6 @@ func Validate(token string) *ValidateUserResponse {
}
converter, _ := json.Marshal(res)
var response ValidateUserResponse
_ = json.Unmarshal(converter, &response)
return &response
resp, _ := utils.Unmarshal[ValidateUserResponse](converter)
return &resp
}

View File

@ -0,0 +1,88 @@
package conversation
import (
"chat/types"
"chat/utils"
)
type Conversation struct {
Username string `json:"username"`
Id int64 `json:"id"`
Message []types.ChatGPTMessage `json:"message"`
}
type FormMessage struct {
Message string `json:"message" binding:"required"`
}
func NewConversation(username string, id int64) *Conversation {
return &Conversation{
Username: username,
Id: id,
Message: []types.ChatGPTMessage{},
}
}
func (c *Conversation) GetUsername() string {
return c.Username
}
func (c *Conversation) GetId() int64 {
return c.Id
}
func (c *Conversation) GetMessage() []types.ChatGPTMessage {
return c.Message
}
func (c *Conversation) GetMessageSize() int {
return len(c.Message)
}
func (c *Conversation) GetMessageSegment(length int) []types.ChatGPTMessage {
if length > len(c.Message) {
return c.Message
}
return c.Message[len(c.Message)-length:]
}
func (c *Conversation) GetLastMessage() types.ChatGPTMessage {
return c.Message[len(c.Message)-1]
}
func (c *Conversation) AddMessage(message types.ChatGPTMessage) {
c.Message = append(c.Message, message)
}
func (c *Conversation) AddMessageFromUser(message string) {
c.Message = append(c.Message, types.ChatGPTMessage{
Role: "user",
Content: message,
})
}
func (c *Conversation) AddMessageFromAssistant(message string) {
c.Message = append(c.Message, types.ChatGPTMessage{
Role: "assistant",
Content: message,
})
}
func (c *Conversation) AddMessageFromSystem(message string) {
c.Message = append(c.Message, types.ChatGPTMessage{
Role: "system",
Content: message,
})
}
func (c *Conversation) AddMessageFromUserForm(data []byte) (string, error) {
form, err := utils.Unmarshal[FormMessage](data)
if err != nil {
return "", err
}
c.Message = append(c.Message, types.ChatGPTMessage{
Role: "user",
Content: form.Message,
})
return form.Message, nil
}

View File

@ -6,7 +6,7 @@ import (
"net/http"
)
var allowedOrigins = []string{
var AllowedOrigins = []string{
"https://fystart.cn",
"https://www.fystart.cn",
"https://nio.fystart.cn",
@ -16,7 +16,7 @@ var allowedOrigins = []string{
func CORSMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
origin := c.Request.Header.Get("Origin")
if utils.Contains(origin, allowedOrigins) {
if utils.Contains(origin, AllowedOrigins) {
c.Writer.Header().Set("Access-Control-Allow-Origin", origin)
c.Writer.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
c.Writer.Header().Set("Access-Control-Allow-Headers", "Origin, Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization")

View File

@ -1,4 +1,4 @@
package api
package types
type ChatGPTMessage struct {
Role string `json:"role"`

View File

@ -1,6 +1,7 @@
package utils
import (
"encoding/json"
"math/rand"
"strconv"
"time"
@ -30,3 +31,8 @@ func ConvertTime(t []uint8) *time.Time {
}
return &val
}
func Unmarshal[T interface{}](data []byte) (form T, err error) {
err = json.Unmarshal(data, &form)
return form, err
}