mirror of
https://github.com/coaidev/coai.git
synced 2025-05-23 23:10:13 +09:00
Implemented feature: chatgpt conversation, conversation segment, model selection, websocket cross site validation
This commit is contained in:
parent
f03daa09c8
commit
9f5a4a2ee6
@ -1,6 +1,7 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"chat/types"
|
||||
"chat/utils"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
@ -19,9 +20,9 @@ func GetAnonymousResponse(message string) (string, error) {
|
||||
res, err := utils.Post(viper.GetString("openai.anonymous_endpoint")+"/chat/completions", map[string]string{
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": "Bearer " + viper.GetString("openai.anonymous"),
|
||||
}, ChatGPTRequest{
|
||||
}, types.ChatGPTRequest{
|
||||
Model: "gpt-3.5-turbo-16k",
|
||||
Messages: []ChatGPTMessage{
|
||||
Messages: []types.ChatGPTMessage{
|
||||
{
|
||||
Role: "user",
|
||||
Content: message,
|
||||
|
41
api/chat.go
41
api/chat.go
@ -1,17 +1,29 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"chat/auth"
|
||||
"chat/conversation"
|
||||
"chat/middleware"
|
||||
"chat/utils"
|
||||
"encoding/json"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/gorilla/websocket"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
type WebsocketAuthForm struct {
|
||||
Token string `json:"token" binding:"required"`
|
||||
}
|
||||
|
||||
func ChatAPI(c *gin.Context) {
|
||||
// websocket connection
|
||||
upgrader := websocket.Upgrader{
|
||||
CheckOrigin: func(r *http.Request) bool {
|
||||
origin := c.Request.Header.Get("Origin")
|
||||
if utils.Contains(origin, middleware.AllowedOrigins) {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
},
|
||||
}
|
||||
conn, err := upgrader.Upgrade(c.Writer, c.Request, nil)
|
||||
@ -29,21 +41,30 @@ func ChatAPI(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
}(conn)
|
||||
for {
|
||||
|
||||
_, message, err := conn.ReadMessage()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
form, err := utils.Unmarshal[WebsocketAuthForm](message)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
var form map[string]interface{}
|
||||
if err := json.Unmarshal(message, &form); err == nil {
|
||||
message := form["message"].(string)
|
||||
StreamRequest("gpt-4", []ChatGPTMessage{
|
||||
{
|
||||
Role: "user",
|
||||
Content: message,
|
||||
},
|
||||
}, 500, func(resp string) {
|
||||
user := auth.ParseToken(c, form.Token)
|
||||
if user == nil {
|
||||
return
|
||||
}
|
||||
|
||||
instance := conversation.NewConversation(user.Username, user.ID)
|
||||
|
||||
for {
|
||||
_, message, err = conn.ReadMessage()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if _, err := instance.AddMessageFromUserForm(message); err == nil {
|
||||
StreamRequest("gpt-3.5-turbo", instance.GetMessageSegment(5), 500, func(resp string) {
|
||||
data, _ := json.Marshal(map[string]interface{}{
|
||||
"message": resp,
|
||||
"end": false,
|
||||
|
@ -1,6 +1,7 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"chat/types"
|
||||
"chat/utils"
|
||||
"crypto/tls"
|
||||
"encoding/json"
|
||||
@ -26,7 +27,7 @@ func processLine(buf []byte) []string {
|
||||
if item == "{data: [DONE]}" {
|
||||
break
|
||||
}
|
||||
var form ChatGPTStreamResponse
|
||||
var form types.ChatGPTStreamResponse
|
||||
if err := json.Unmarshal([]byte(item), &form); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
@ -38,11 +39,11 @@ func processLine(buf []byte) []string {
|
||||
return resp
|
||||
}
|
||||
|
||||
func StreamRequest(model string, messages []ChatGPTMessage, token int, callback func(string)) {
|
||||
func StreamRequest(model string, messages []types.ChatGPTMessage, token int, callback func(string)) {
|
||||
http.DefaultTransport.(*http.Transport).TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
|
||||
|
||||
client := &http.Client{}
|
||||
req, err := http.NewRequest("POST", viper.GetString("openai.user_endpoint")+"/chat/completions", utils.ConvertBody(ChatGPTRequest{
|
||||
req, err := http.NewRequest("POST", viper.GetString("openai.user_endpoint")+"/chat/completions", utils.ConvertBody(types.ChatGPTRequest{
|
||||
Model: model,
|
||||
Messages: messages,
|
||||
MaxToken: token,
|
||||
|
@ -1,6 +1,6 @@
|
||||
import axios from "axios";
|
||||
|
||||
export const deploy: boolean = false;
|
||||
export const deploy: boolean = true;
|
||||
export let rest_api: string = "http://localhost:8094";
|
||||
export let ws_api: string = "ws://localhost:8094";
|
||||
|
||||
|
@ -1,7 +1,7 @@
|
||||
import {nextTick, reactive, ref} from "vue";
|
||||
import type { Ref } from "vue";
|
||||
import axios from "axios";
|
||||
import {auth} from "./auth";
|
||||
import {auth, token} from "./auth";
|
||||
import {ws_api} from "./conf";
|
||||
|
||||
type Message = {
|
||||
@ -31,6 +31,9 @@ export class Connection {
|
||||
this.state = false;
|
||||
this.connection.onopen = () => {
|
||||
this.state = true;
|
||||
this.send({
|
||||
token: token.value,
|
||||
})
|
||||
}
|
||||
this.connection.onclose = () => {
|
||||
this.state = false;
|
||||
@ -91,13 +94,13 @@ export class Conversation {
|
||||
message.value += res.message;
|
||||
end.value = res.end;
|
||||
})
|
||||
this.addDynamicMessageFromAI(message, end);
|
||||
const status = this.connection?.send({
|
||||
message: content,
|
||||
});
|
||||
if (!status) {
|
||||
if (status) {
|
||||
this.addDynamicMessageFromAI(message, end);
|
||||
} else {
|
||||
this.addMessageFromAI("网络错误,请稍后再试");
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -26,7 +26,6 @@ func Validate(token string) *ValidateUserResponse {
|
||||
}
|
||||
|
||||
converter, _ := json.Marshal(res)
|
||||
var response ValidateUserResponse
|
||||
_ = json.Unmarshal(converter, &response)
|
||||
return &response
|
||||
resp, _ := utils.Unmarshal[ValidateUserResponse](converter)
|
||||
return &resp
|
||||
}
|
||||
|
88
conversation/conversation.go
Normal file
88
conversation/conversation.go
Normal file
@ -0,0 +1,88 @@
|
||||
package conversation
|
||||
|
||||
import (
|
||||
"chat/types"
|
||||
"chat/utils"
|
||||
)
|
||||
|
||||
type Conversation struct {
|
||||
Username string `json:"username"`
|
||||
Id int64 `json:"id"`
|
||||
Message []types.ChatGPTMessage `json:"message"`
|
||||
}
|
||||
|
||||
type FormMessage struct {
|
||||
Message string `json:"message" binding:"required"`
|
||||
}
|
||||
|
||||
func NewConversation(username string, id int64) *Conversation {
|
||||
return &Conversation{
|
||||
Username: username,
|
||||
Id: id,
|
||||
Message: []types.ChatGPTMessage{},
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Conversation) GetUsername() string {
|
||||
return c.Username
|
||||
}
|
||||
|
||||
func (c *Conversation) GetId() int64 {
|
||||
return c.Id
|
||||
}
|
||||
|
||||
func (c *Conversation) GetMessage() []types.ChatGPTMessage {
|
||||
return c.Message
|
||||
}
|
||||
|
||||
func (c *Conversation) GetMessageSize() int {
|
||||
return len(c.Message)
|
||||
}
|
||||
|
||||
func (c *Conversation) GetMessageSegment(length int) []types.ChatGPTMessage {
|
||||
if length > len(c.Message) {
|
||||
return c.Message
|
||||
}
|
||||
return c.Message[len(c.Message)-length:]
|
||||
}
|
||||
|
||||
func (c *Conversation) GetLastMessage() types.ChatGPTMessage {
|
||||
return c.Message[len(c.Message)-1]
|
||||
}
|
||||
|
||||
func (c *Conversation) AddMessage(message types.ChatGPTMessage) {
|
||||
c.Message = append(c.Message, message)
|
||||
}
|
||||
|
||||
func (c *Conversation) AddMessageFromUser(message string) {
|
||||
c.Message = append(c.Message, types.ChatGPTMessage{
|
||||
Role: "user",
|
||||
Content: message,
|
||||
})
|
||||
}
|
||||
|
||||
func (c *Conversation) AddMessageFromAssistant(message string) {
|
||||
c.Message = append(c.Message, types.ChatGPTMessage{
|
||||
Role: "assistant",
|
||||
Content: message,
|
||||
})
|
||||
}
|
||||
|
||||
func (c *Conversation) AddMessageFromSystem(message string) {
|
||||
c.Message = append(c.Message, types.ChatGPTMessage{
|
||||
Role: "system",
|
||||
Content: message,
|
||||
})
|
||||
}
|
||||
|
||||
func (c *Conversation) AddMessageFromUserForm(data []byte) (string, error) {
|
||||
form, err := utils.Unmarshal[FormMessage](data)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
c.Message = append(c.Message, types.ChatGPTMessage{
|
||||
Role: "user",
|
||||
Content: form.Message,
|
||||
})
|
||||
return form.Message, nil
|
||||
}
|
@ -6,7 +6,7 @@ import (
|
||||
"net/http"
|
||||
)
|
||||
|
||||
var allowedOrigins = []string{
|
||||
var AllowedOrigins = []string{
|
||||
"https://fystart.cn",
|
||||
"https://www.fystart.cn",
|
||||
"https://nio.fystart.cn",
|
||||
@ -16,7 +16,7 @@ var allowedOrigins = []string{
|
||||
func CORSMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
origin := c.Request.Header.Get("Origin")
|
||||
if utils.Contains(origin, allowedOrigins) {
|
||||
if utils.Contains(origin, AllowedOrigins) {
|
||||
c.Writer.Header().Set("Access-Control-Allow-Origin", origin)
|
||||
c.Writer.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
|
||||
c.Writer.Header().Set("Access-Control-Allow-Headers", "Origin, Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization")
|
||||
|
@ -1,4 +1,4 @@
|
||||
package api
|
||||
package types
|
||||
|
||||
type ChatGPTMessage struct {
|
||||
Role string `json:"role"`
|
@ -1,6 +1,7 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"math/rand"
|
||||
"strconv"
|
||||
"time"
|
||||
@ -30,3 +31,8 @@ func ConvertTime(t []uint8) *time.Time {
|
||||
}
|
||||
return &val
|
||||
}
|
||||
|
||||
func Unmarshal[T interface{}](data []byte) (form T, err error) {
|
||||
err = json.Unmarshal(data, &form)
|
||||
return form, err
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user