mirror of
https://github.com/coaidev/coai.git
synced 2025-05-23 23:10:13 +09:00
Implemented feature: chatgpt conversation, conversation segment, model selection, websocket cross site validation
This commit is contained in:
parent
f03daa09c8
commit
9f5a4a2ee6
@ -1,6 +1,7 @@
|
|||||||
package api
|
package api
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"chat/types"
|
||||||
"chat/utils"
|
"chat/utils"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
@ -19,9 +20,9 @@ func GetAnonymousResponse(message string) (string, error) {
|
|||||||
res, err := utils.Post(viper.GetString("openai.anonymous_endpoint")+"/chat/completions", map[string]string{
|
res, err := utils.Post(viper.GetString("openai.anonymous_endpoint")+"/chat/completions", map[string]string{
|
||||||
"Content-Type": "application/json",
|
"Content-Type": "application/json",
|
||||||
"Authorization": "Bearer " + viper.GetString("openai.anonymous"),
|
"Authorization": "Bearer " + viper.GetString("openai.anonymous"),
|
||||||
}, ChatGPTRequest{
|
}, types.ChatGPTRequest{
|
||||||
Model: "gpt-3.5-turbo-16k",
|
Model: "gpt-3.5-turbo-16k",
|
||||||
Messages: []ChatGPTMessage{
|
Messages: []types.ChatGPTMessage{
|
||||||
{
|
{
|
||||||
Role: "user",
|
Role: "user",
|
||||||
Content: message,
|
Content: message,
|
||||||
|
45
api/chat.go
45
api/chat.go
@ -1,17 +1,29 @@
|
|||||||
package api
|
package api
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"chat/auth"
|
||||||
|
"chat/conversation"
|
||||||
|
"chat/middleware"
|
||||||
|
"chat/utils"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/gorilla/websocket"
|
"github.com/gorilla/websocket"
|
||||||
"net/http"
|
"net/http"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type WebsocketAuthForm struct {
|
||||||
|
Token string `json:"token" binding:"required"`
|
||||||
|
}
|
||||||
|
|
||||||
func ChatAPI(c *gin.Context) {
|
func ChatAPI(c *gin.Context) {
|
||||||
// websocket connection
|
// websocket connection
|
||||||
upgrader := websocket.Upgrader{
|
upgrader := websocket.Upgrader{
|
||||||
CheckOrigin: func(r *http.Request) bool {
|
CheckOrigin: func(r *http.Request) bool {
|
||||||
return true
|
origin := c.Request.Header.Get("Origin")
|
||||||
|
if utils.Contains(origin, middleware.AllowedOrigins) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
conn, err := upgrader.Upgrade(c.Writer, c.Request, nil)
|
conn, err := upgrader.Upgrade(c.Writer, c.Request, nil)
|
||||||
@ -29,21 +41,30 @@ func ChatAPI(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
}(conn)
|
}(conn)
|
||||||
|
|
||||||
|
_, message, err := conn.ReadMessage()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
form, err := utils.Unmarshal[WebsocketAuthForm](message)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
user := auth.ParseToken(c, form.Token)
|
||||||
|
if user == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
instance := conversation.NewConversation(user.Username, user.ID)
|
||||||
|
|
||||||
for {
|
for {
|
||||||
_, message, err := conn.ReadMessage()
|
_, message, err = conn.ReadMessage()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if _, err := instance.AddMessageFromUserForm(message); err == nil {
|
||||||
var form map[string]interface{}
|
StreamRequest("gpt-3.5-turbo", instance.GetMessageSegment(5), 500, func(resp string) {
|
||||||
if err := json.Unmarshal(message, &form); err == nil {
|
|
||||||
message := form["message"].(string)
|
|
||||||
StreamRequest("gpt-4", []ChatGPTMessage{
|
|
||||||
{
|
|
||||||
Role: "user",
|
|
||||||
Content: message,
|
|
||||||
},
|
|
||||||
}, 500, func(resp string) {
|
|
||||||
data, _ := json.Marshal(map[string]interface{}{
|
data, _ := json.Marshal(map[string]interface{}{
|
||||||
"message": resp,
|
"message": resp,
|
||||||
"end": false,
|
"end": false,
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
package api
|
package api
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"chat/types"
|
||||||
"chat/utils"
|
"chat/utils"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
@ -26,7 +27,7 @@ func processLine(buf []byte) []string {
|
|||||||
if item == "{data: [DONE]}" {
|
if item == "{data: [DONE]}" {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
var form ChatGPTStreamResponse
|
var form types.ChatGPTStreamResponse
|
||||||
if err := json.Unmarshal([]byte(item), &form); err != nil {
|
if err := json.Unmarshal([]byte(item), &form); err != nil {
|
||||||
log.Fatal(err)
|
log.Fatal(err)
|
||||||
}
|
}
|
||||||
@ -38,11 +39,11 @@ func processLine(buf []byte) []string {
|
|||||||
return resp
|
return resp
|
||||||
}
|
}
|
||||||
|
|
||||||
func StreamRequest(model string, messages []ChatGPTMessage, token int, callback func(string)) {
|
func StreamRequest(model string, messages []types.ChatGPTMessage, token int, callback func(string)) {
|
||||||
http.DefaultTransport.(*http.Transport).TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
|
http.DefaultTransport.(*http.Transport).TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
|
||||||
|
|
||||||
client := &http.Client{}
|
client := &http.Client{}
|
||||||
req, err := http.NewRequest("POST", viper.GetString("openai.user_endpoint")+"/chat/completions", utils.ConvertBody(ChatGPTRequest{
|
req, err := http.NewRequest("POST", viper.GetString("openai.user_endpoint")+"/chat/completions", utils.ConvertBody(types.ChatGPTRequest{
|
||||||
Model: model,
|
Model: model,
|
||||||
Messages: messages,
|
Messages: messages,
|
||||||
MaxToken: token,
|
MaxToken: token,
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
import axios from "axios";
|
import axios from "axios";
|
||||||
|
|
||||||
export const deploy: boolean = false;
|
export const deploy: boolean = true;
|
||||||
export let rest_api: string = "http://localhost:8094";
|
export let rest_api: string = "http://localhost:8094";
|
||||||
export let ws_api: string = "ws://localhost:8094";
|
export let ws_api: string = "ws://localhost:8094";
|
||||||
|
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
import {nextTick, reactive, ref} from "vue";
|
import {nextTick, reactive, ref} from "vue";
|
||||||
import type { Ref } from "vue";
|
import type { Ref } from "vue";
|
||||||
import axios from "axios";
|
import axios from "axios";
|
||||||
import {auth} from "./auth";
|
import {auth, token} from "./auth";
|
||||||
import {ws_api} from "./conf";
|
import {ws_api} from "./conf";
|
||||||
|
|
||||||
type Message = {
|
type Message = {
|
||||||
@ -31,6 +31,9 @@ export class Connection {
|
|||||||
this.state = false;
|
this.state = false;
|
||||||
this.connection.onopen = () => {
|
this.connection.onopen = () => {
|
||||||
this.state = true;
|
this.state = true;
|
||||||
|
this.send({
|
||||||
|
token: token.value,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
this.connection.onclose = () => {
|
this.connection.onclose = () => {
|
||||||
this.state = false;
|
this.state = false;
|
||||||
@ -91,13 +94,13 @@ export class Conversation {
|
|||||||
message.value += res.message;
|
message.value += res.message;
|
||||||
end.value = res.end;
|
end.value = res.end;
|
||||||
})
|
})
|
||||||
this.addDynamicMessageFromAI(message, end);
|
|
||||||
const status = this.connection?.send({
|
const status = this.connection?.send({
|
||||||
message: content,
|
message: content,
|
||||||
});
|
});
|
||||||
if (!status) {
|
if (status) {
|
||||||
|
this.addDynamicMessageFromAI(message, end);
|
||||||
|
} else {
|
||||||
this.addMessageFromAI("网络错误,请稍后再试");
|
this.addMessageFromAI("网络错误,请稍后再试");
|
||||||
return;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -26,7 +26,6 @@ func Validate(token string) *ValidateUserResponse {
|
|||||||
}
|
}
|
||||||
|
|
||||||
converter, _ := json.Marshal(res)
|
converter, _ := json.Marshal(res)
|
||||||
var response ValidateUserResponse
|
resp, _ := utils.Unmarshal[ValidateUserResponse](converter)
|
||||||
_ = json.Unmarshal(converter, &response)
|
return &resp
|
||||||
return &response
|
|
||||||
}
|
}
|
||||||
|
88
conversation/conversation.go
Normal file
88
conversation/conversation.go
Normal file
@ -0,0 +1,88 @@
|
|||||||
|
package conversation
|
||||||
|
|
||||||
|
import (
|
||||||
|
"chat/types"
|
||||||
|
"chat/utils"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Conversation struct {
|
||||||
|
Username string `json:"username"`
|
||||||
|
Id int64 `json:"id"`
|
||||||
|
Message []types.ChatGPTMessage `json:"message"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type FormMessage struct {
|
||||||
|
Message string `json:"message" binding:"required"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewConversation(username string, id int64) *Conversation {
|
||||||
|
return &Conversation{
|
||||||
|
Username: username,
|
||||||
|
Id: id,
|
||||||
|
Message: []types.ChatGPTMessage{},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conversation) GetUsername() string {
|
||||||
|
return c.Username
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conversation) GetId() int64 {
|
||||||
|
return c.Id
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conversation) GetMessage() []types.ChatGPTMessage {
|
||||||
|
return c.Message
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conversation) GetMessageSize() int {
|
||||||
|
return len(c.Message)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conversation) GetMessageSegment(length int) []types.ChatGPTMessage {
|
||||||
|
if length > len(c.Message) {
|
||||||
|
return c.Message
|
||||||
|
}
|
||||||
|
return c.Message[len(c.Message)-length:]
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conversation) GetLastMessage() types.ChatGPTMessage {
|
||||||
|
return c.Message[len(c.Message)-1]
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conversation) AddMessage(message types.ChatGPTMessage) {
|
||||||
|
c.Message = append(c.Message, message)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conversation) AddMessageFromUser(message string) {
|
||||||
|
c.Message = append(c.Message, types.ChatGPTMessage{
|
||||||
|
Role: "user",
|
||||||
|
Content: message,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conversation) AddMessageFromAssistant(message string) {
|
||||||
|
c.Message = append(c.Message, types.ChatGPTMessage{
|
||||||
|
Role: "assistant",
|
||||||
|
Content: message,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conversation) AddMessageFromSystem(message string) {
|
||||||
|
c.Message = append(c.Message, types.ChatGPTMessage{
|
||||||
|
Role: "system",
|
||||||
|
Content: message,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conversation) AddMessageFromUserForm(data []byte) (string, error) {
|
||||||
|
form, err := utils.Unmarshal[FormMessage](data)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
c.Message = append(c.Message, types.ChatGPTMessage{
|
||||||
|
Role: "user",
|
||||||
|
Content: form.Message,
|
||||||
|
})
|
||||||
|
return form.Message, nil
|
||||||
|
}
|
@ -6,7 +6,7 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
)
|
)
|
||||||
|
|
||||||
var allowedOrigins = []string{
|
var AllowedOrigins = []string{
|
||||||
"https://fystart.cn",
|
"https://fystart.cn",
|
||||||
"https://www.fystart.cn",
|
"https://www.fystart.cn",
|
||||||
"https://nio.fystart.cn",
|
"https://nio.fystart.cn",
|
||||||
@ -16,7 +16,7 @@ var allowedOrigins = []string{
|
|||||||
func CORSMiddleware() gin.HandlerFunc {
|
func CORSMiddleware() gin.HandlerFunc {
|
||||||
return func(c *gin.Context) {
|
return func(c *gin.Context) {
|
||||||
origin := c.Request.Header.Get("Origin")
|
origin := c.Request.Header.Get("Origin")
|
||||||
if utils.Contains(origin, allowedOrigins) {
|
if utils.Contains(origin, AllowedOrigins) {
|
||||||
c.Writer.Header().Set("Access-Control-Allow-Origin", origin)
|
c.Writer.Header().Set("Access-Control-Allow-Origin", origin)
|
||||||
c.Writer.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
|
c.Writer.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
|
||||||
c.Writer.Header().Set("Access-Control-Allow-Headers", "Origin, Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization")
|
c.Writer.Header().Set("Access-Control-Allow-Headers", "Origin, Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization")
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
package api
|
package types
|
||||||
|
|
||||||
type ChatGPTMessage struct {
|
type ChatGPTMessage struct {
|
||||||
Role string `json:"role"`
|
Role string `json:"role"`
|
@ -1,6 +1,7 @@
|
|||||||
package utils
|
package utils
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/json"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"strconv"
|
"strconv"
|
||||||
"time"
|
"time"
|
||||||
@ -30,3 +31,8 @@ func ConvertTime(t []uint8) *time.Time {
|
|||||||
}
|
}
|
||||||
return &val
|
return &val
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func Unmarshal[T interface{}](data []byte) (form T, err error) {
|
||||||
|
err = json.Unmarshal(data, &form)
|
||||||
|
return form, err
|
||||||
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user