mirror of
https://github.com/coaidev/coai.git
synced 2025-05-21 22:10:12 +09:00
v3 restruct
This commit is contained in:
parent
51f025b0f4
commit
6ac6e784ef
4
.gitignore
vendored
4
.gitignore
vendored
@ -2,8 +2,8 @@ node_modules
|
||||
.vscode
|
||||
.idea
|
||||
config.yaml
|
||||
generation/data/*
|
||||
!generation/data/.gitkeep
|
||||
addition/generation/data/*
|
||||
!addition/generation/data/.gitkeep
|
||||
|
||||
chat
|
||||
chat.exe
|
||||
|
39
adapter/adapter.go
Normal file
39
adapter/adapter.go
Normal file
@ -0,0 +1,39 @@
|
||||
package adapter
|
||||
|
||||
import (
|
||||
"chat/adapter/chatgpt"
|
||||
"chat/globals"
|
||||
"chat/utils"
|
||||
"github.com/spf13/viper"
|
||||
)
|
||||
|
||||
type ChatProps struct {
|
||||
Model string
|
||||
Reversible bool
|
||||
Infinity bool
|
||||
Message []globals.Message
|
||||
}
|
||||
|
||||
type Hook func(data string) error
|
||||
|
||||
func NewChatRequest(props *ChatProps, hook Hook) error {
|
||||
if globals.IsClaudeModel(props.Model) {
|
||||
return nil // work in progress
|
||||
} else if globals.IsChatGPTModel(props.Model) {
|
||||
instance := chatgpt.NewChatInstanceFromModel(&chatgpt.InstanceProps{
|
||||
Model: props.Model,
|
||||
Reversible: props.Reversible,
|
||||
})
|
||||
return instance.CreateStreamChatRequest(&chatgpt.ChatProps{
|
||||
Model: utils.Multi(
|
||||
props.Reversible && globals.IsGPT4NativeModel(props.Model),
|
||||
viper.GetString("openai.reverse.hash"),
|
||||
props.Model,
|
||||
),
|
||||
Message: props.Message,
|
||||
Token: utils.Multi(globals.IsGPT4Model(props.Model) || props.Reversible || props.Infinity, -1, 2000),
|
||||
}, hook)
|
||||
} else {
|
||||
return nil
|
||||
}
|
||||
}
|
140
adapter/chatgpt/chat.go
Normal file
140
adapter/chatgpt/chat.go
Normal file
@ -0,0 +1,140 @@
|
||||
package chatgpt
|
||||
|
||||
import "C"
|
||||
import (
|
||||
"chat/globals"
|
||||
"chat/utils"
|
||||
"fmt"
|
||||
"github.com/spf13/viper"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type ChatProps struct {
|
||||
Model string
|
||||
Message []globals.Message
|
||||
Token int
|
||||
}
|
||||
|
||||
func (c *ChatInstance) GetChatEndpoint() string {
|
||||
return fmt.Sprintf("%s/v1/chat/completions", c.GetEndpoint())
|
||||
}
|
||||
|
||||
func (c *ChatInstance) GetChatBody(props *ChatProps, stream bool) interface{} {
|
||||
if props.Token != -1 {
|
||||
return ChatRequest{
|
||||
Model: props.Model,
|
||||
Messages: props.Message,
|
||||
MaxToken: props.Token,
|
||||
Stream: stream,
|
||||
}
|
||||
}
|
||||
|
||||
return ChatRequestWithInfinity{
|
||||
Model: props.Model,
|
||||
Messages: props.Message,
|
||||
Stream: stream,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ChatInstance) ProcessLine(data string) string {
|
||||
rep := strings.NewReplacer(
|
||||
"data: {",
|
||||
"\"data\": {",
|
||||
)
|
||||
item := rep.Replace(data)
|
||||
if !strings.HasPrefix(item, "{") {
|
||||
item = "{" + item
|
||||
}
|
||||
if !strings.HasSuffix(item, "}}") {
|
||||
item = item + "}"
|
||||
}
|
||||
|
||||
if item == "{data: [DONE]}" || item == "{data: [DONE]}}" || item == "{[DONE]}" {
|
||||
return ""
|
||||
} else if item == "{data:}" || item == "{data:}}" {
|
||||
return ""
|
||||
}
|
||||
|
||||
var form *ChatStreamResponse
|
||||
if form = utils.UnmarshalForm[ChatStreamResponse](item); form == nil {
|
||||
if form = utils.UnmarshalForm[ChatStreamResponse](item); form == nil {
|
||||
return fmt.Sprintf("%s\n", item)
|
||||
}
|
||||
}
|
||||
|
||||
if len(form.Data.Choices) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
return form.Data.Choices[0].Delta.Content
|
||||
}
|
||||
|
||||
// CreateChatRequest is the native http request body for chatgpt
|
||||
func (c *ChatInstance) CreateChatRequest(props *ChatProps) (string, error) {
|
||||
res, err := utils.Post(
|
||||
c.GetChatEndpoint(),
|
||||
c.GetHeader(),
|
||||
c.GetChatBody(props, false),
|
||||
)
|
||||
|
||||
if err != nil || res == nil {
|
||||
return "", fmt.Errorf("chatgpt error: %s", err.Error())
|
||||
}
|
||||
|
||||
data := utils.MapToStruct[ChatResponse](res)
|
||||
if data == nil {
|
||||
return "", fmt.Errorf("chatgpt error: cannot parse response")
|
||||
} else if data.Error.Message != "" {
|
||||
return "", fmt.Errorf("chatgpt error: %s", data.Error.Message)
|
||||
}
|
||||
return data.Choices[0].Message.Content, nil
|
||||
}
|
||||
|
||||
// CreateStreamChatRequest is the stream response body for chatgpt
|
||||
func (c *ChatInstance) CreateStreamChatRequest(props *ChatProps, callback func(string) error) error {
|
||||
return utils.EventSource(
|
||||
"POST",
|
||||
c.GetChatEndpoint(),
|
||||
c.GetHeader(),
|
||||
c.GetChatBody(props, true),
|
||||
func(data string) error {
|
||||
if data := c.ProcessLine(data); data != "" {
|
||||
if err := callback(data); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
func (c *ChatInstance) Test() bool {
|
||||
result, err := c.CreateChatRequest(&ChatProps{
|
||||
Model: globals.GPT3Turbo,
|
||||
Message: []globals.Message{{Role: "user", Content: "hi"}},
|
||||
Token: 1,
|
||||
})
|
||||
|
||||
return err == nil && len(result) > 0
|
||||
}
|
||||
|
||||
func FilterKeys(v string) string {
|
||||
endpoint := viper.GetString(fmt.Sprintf("openai.%s.endpoint", v))
|
||||
keys := strings.Split(viper.GetString(fmt.Sprintf("openai.%s.apikey", v)), "|")
|
||||
|
||||
stack := make(chan string, len(keys))
|
||||
for _, key := range keys {
|
||||
go func(key string) {
|
||||
instance := NewChatInstance(endpoint, key)
|
||||
stack <- utils.Multi[string](instance.Test(), key, "")
|
||||
}(key)
|
||||
}
|
||||
|
||||
var result []string
|
||||
for i := 0; i < len(keys); i++ {
|
||||
if res := <-stack; res != "" {
|
||||
result = append(result, res)
|
||||
}
|
||||
}
|
||||
return strings.Join(result, "|")
|
||||
}
|
38
adapter/chatgpt/dalle.go
Normal file
38
adapter/chatgpt/dalle.go
Normal file
@ -0,0 +1,38 @@
|
||||
package chatgpt
|
||||
|
||||
import (
|
||||
"chat/utils"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
type ImageProps struct {
|
||||
Prompt string
|
||||
Size ImageSize
|
||||
}
|
||||
|
||||
func (c *ChatInstance) GetImageEndpoint() string {
|
||||
return fmt.Sprintf("%s/v1/images/generations", c.GetEndpoint())
|
||||
}
|
||||
|
||||
// CreateImage will create a dalle image from prompt, return url of image and error
|
||||
func (c *ChatInstance) CreateImage(props ImageProps) (string, error) {
|
||||
res, err := utils.Post(
|
||||
c.GetImageEndpoint(),
|
||||
c.GetHeader(), ImageRequest{
|
||||
Prompt: props.Prompt,
|
||||
Size: utils.Multi[ImageSize](len(props.Size) == 0, ImageSize512, props.Size),
|
||||
N: 1,
|
||||
})
|
||||
if err != nil || res == nil {
|
||||
return "", fmt.Errorf("chatgpt error: %s", err.Error())
|
||||
}
|
||||
|
||||
data := utils.MapToStruct[ImageResponse](res)
|
||||
if data == nil {
|
||||
return "", fmt.Errorf("chatgpt error: cannot parse response")
|
||||
} else if data.Error.Message != "" {
|
||||
return "", fmt.Errorf("chatgpt error: %s", data.Error.Message)
|
||||
}
|
||||
|
||||
return data.Data[0].Url, nil
|
||||
}
|
72
adapter/chatgpt/struct.go
Normal file
72
adapter/chatgpt/struct.go
Normal file
@ -0,0 +1,72 @@
|
||||
package chatgpt
|
||||
|
||||
import (
|
||||
"chat/globals"
|
||||
"chat/utils"
|
||||
"fmt"
|
||||
"github.com/spf13/viper"
|
||||
)
|
||||
|
||||
type ChatInstance struct {
|
||||
Endpoint string
|
||||
ApiKey string
|
||||
}
|
||||
|
||||
type InstanceProps struct {
|
||||
Model string
|
||||
Reversible bool
|
||||
}
|
||||
|
||||
func (c *ChatInstance) GetEndpoint() string {
|
||||
return c.Endpoint
|
||||
}
|
||||
|
||||
func (c *ChatInstance) GetApiKey() string {
|
||||
return c.ApiKey
|
||||
}
|
||||
|
||||
func (c *ChatInstance) GetHeader() map[string]string {
|
||||
return map[string]string{
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": fmt.Sprintf("Bearer %s", c.GetApiKey()),
|
||||
}
|
||||
}
|
||||
|
||||
func NewChatInstance(endpoint, apiKey string) *ChatInstance {
|
||||
return &ChatInstance{
|
||||
Endpoint: endpoint,
|
||||
ApiKey: apiKey,
|
||||
}
|
||||
}
|
||||
|
||||
func NewChatInstanceFromConfig(v string) *ChatInstance {
|
||||
return NewChatInstance(
|
||||
viper.GetString(fmt.Sprintf("openai.%s.endpoint", v)),
|
||||
utils.GetRandomKey(viper.GetString(fmt.Sprintf("openai.%s.apikey", v))),
|
||||
)
|
||||
}
|
||||
|
||||
func NewChatInstanceFromModel(props *InstanceProps) *ChatInstance {
|
||||
switch props.Model {
|
||||
case globals.GPT4,
|
||||
globals.GPT40314,
|
||||
globals.GPT40613:
|
||||
if props.Reversible {
|
||||
return NewChatInstanceFromConfig("reverse")
|
||||
} else {
|
||||
return NewChatInstanceFromConfig("gpt4")
|
||||
}
|
||||
|
||||
case globals.GPT432k,
|
||||
globals.GPT432k0613,
|
||||
globals.GPT432k0314:
|
||||
return NewChatInstanceFromConfig("gpt4")
|
||||
|
||||
case globals.GPT3Turbo16k,
|
||||
globals.GPT3Turbo16k0301,
|
||||
globals.GPT3Turbo16k0613:
|
||||
return NewChatInstanceFromConfig("gpt3")
|
||||
default:
|
||||
return NewChatInstanceFromConfig("gpt3")
|
||||
}
|
||||
}
|
73
adapter/chatgpt/types.go
Normal file
73
adapter/chatgpt/types.go
Normal file
@ -0,0 +1,73 @@
|
||||
package chatgpt
|
||||
|
||||
import "chat/globals"
|
||||
|
||||
// ChatRequest is the request body for chatgpt
|
||||
type ChatRequest struct {
|
||||
Model string `json:"model"`
|
||||
Messages []globals.Message `json:"messages"`
|
||||
MaxToken int `json:"max_tokens"`
|
||||
Stream bool `json:"stream"`
|
||||
}
|
||||
|
||||
type ChatRequestWithInfinity struct {
|
||||
Model string `json:"model"`
|
||||
Messages []globals.Message `json:"messages"`
|
||||
Stream bool `json:"stream"`
|
||||
}
|
||||
|
||||
// ChatResponse is the native http request body for chatgpt
|
||||
type ChatResponse struct {
|
||||
ID string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
Created int64 `json:"created"`
|
||||
Model string `json:"model"`
|
||||
Choices []struct {
|
||||
Message struct {
|
||||
Content string `json:"content"`
|
||||
}
|
||||
} `json:"choices"`
|
||||
Error struct {
|
||||
Message string `json:"message"`
|
||||
} `json:"error"`
|
||||
}
|
||||
|
||||
// ChatStreamResponse is the stream response body for chatgpt
|
||||
type ChatStreamResponse struct {
|
||||
ID string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
Created int64 `json:"created"`
|
||||
Model string `json:"model"`
|
||||
Data struct {
|
||||
Choices []struct {
|
||||
Delta struct {
|
||||
Content string `json:"content"`
|
||||
}
|
||||
Index int `json:"index"`
|
||||
} `json:"choices"`
|
||||
} `json:"data"`
|
||||
}
|
||||
|
||||
type ImageSize string
|
||||
|
||||
// ImageRequest is the request body for chatgpt dalle image generation
|
||||
type ImageRequest struct {
|
||||
Prompt string `json:"prompt"`
|
||||
Size ImageSize `json:"size"`
|
||||
N int `json:"n"`
|
||||
}
|
||||
|
||||
type ImageResponse struct {
|
||||
Data []struct {
|
||||
Url string `json:"url"`
|
||||
} `json:"data"`
|
||||
Error struct {
|
||||
Message string `json:"message"`
|
||||
} `json:"error"`
|
||||
}
|
||||
|
||||
var (
|
||||
ImageSize256 ImageSize = "256x256"
|
||||
ImageSize512 ImageSize = "512x512"
|
||||
ImageSize1024 ImageSize = "1024x1024"
|
||||
)
|
@ -1,13 +1,19 @@
|
||||
package api
|
||||
package card
|
||||
|
||||
import (
|
||||
"chat/types"
|
||||
"chat/globals"
|
||||
"chat/manager"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/russross/blackfriday/v2"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type RequestForm struct {
|
||||
Message string `json:"message" required:"true"`
|
||||
Web bool `json:"web"`
|
||||
}
|
||||
|
||||
const maxColumnPerLine = 50
|
||||
|
||||
func ProcessMarkdownLine(source []byte) string {
|
||||
@ -43,8 +49,8 @@ func MarkdownConvert(text string) string {
|
||||
return string(result)
|
||||
}
|
||||
|
||||
func CardAPI(c *gin.Context) {
|
||||
var body AnonymousRequestBody
|
||||
func HandlerAPI(c *gin.Context) {
|
||||
var body RequestForm
|
||||
if err := c.ShouldBindJSON(&body); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": "invalid request body",
|
||||
@ -58,13 +64,13 @@ func CardAPI(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
key, res, err := GetAnonymousResponseWithCache(c, message, body.Web, []types.ChatGPTMessage{})
|
||||
if err != nil {
|
||||
res = "There was something wrong..."
|
||||
}
|
||||
keyword, response, quota := manager.NativeChatHandler(c, nil, globals.GPT3Turbo0613, []globals.Message{
|
||||
{Role: "user", Content: message},
|
||||
}, body.Web)
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": MarkdownConvert(res),
|
||||
"keyword": key,
|
||||
"message": MarkdownConvert(response),
|
||||
"keyword": keyword,
|
||||
"quota": quota,
|
||||
})
|
||||
}
|
Before Width: | Height: | Size: 8.7 KiB After Width: | Height: | Size: 8.7 KiB |
102
addition/generation/api.go
Normal file
102
addition/generation/api.go
Normal file
@ -0,0 +1,102 @@
|
||||
package generation
|
||||
|
||||
import (
|
||||
"chat/auth"
|
||||
"chat/globals"
|
||||
"chat/utils"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type WebsocketGenerationForm struct {
|
||||
Token string `json:"token" binding:"required"`
|
||||
Prompt string `json:"prompt" binding:"required"`
|
||||
Model string `json:"model" binding:"required"`
|
||||
}
|
||||
|
||||
func ProjectTarDownloadAPI(c *gin.Context) {
|
||||
hash := strings.TrimSpace(c.Query("hash"))
|
||||
c.Writer.Header().Add("Content-Disposition", "attachment; filename=code.tar.gz")
|
||||
c.File(fmt.Sprintf("generation/data/out/%s.tar.gz", hash))
|
||||
}
|
||||
|
||||
func ProjectZipDownloadAPI(c *gin.Context) {
|
||||
hash := strings.TrimSpace(c.Query("hash"))
|
||||
c.Writer.Header().Add("Content-Disposition", "attachment; filename=code.zip")
|
||||
c.File(fmt.Sprintf("generation/data/out/%s.zip", hash))
|
||||
}
|
||||
|
||||
func GenerateAPI(c *gin.Context) {
|
||||
var conn *utils.WebSocket
|
||||
if conn = utils.NewWebsocket(c); conn == nil {
|
||||
return
|
||||
}
|
||||
defer conn.DeferClose()
|
||||
|
||||
var form *WebsocketGenerationForm
|
||||
if form = utils.ReadForm[WebsocketGenerationForm](conn); form == nil {
|
||||
return
|
||||
}
|
||||
|
||||
user := auth.ParseToken(c, form.Token)
|
||||
authenticated := user != nil
|
||||
|
||||
db := utils.GetDBFromContext(c)
|
||||
cache := utils.GetCacheFromContext(c)
|
||||
|
||||
id := auth.GetId(db, user)
|
||||
|
||||
if !utils.IncrWithLimit(cache,
|
||||
fmt.Sprintf(":generation:%s", utils.Multi[string](authenticated, strconv.FormatInt(id, 10), c.ClientIP())),
|
||||
1,
|
||||
30,
|
||||
3600,
|
||||
) {
|
||||
conn.Send(globals.GenerationSegmentResponse{
|
||||
End: true,
|
||||
Error: "generation rate limit exceeded, the max generation rate is 30 per hour.",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
reversible := auth.CanEnableSubscription(db, cache, user)
|
||||
if !auth.CanEnableModelWithSubscription(db, user, form.Model, reversible) {
|
||||
conn.Send(globals.GenerationSegmentResponse{
|
||||
Message: "You don't have enough quota to use this model.",
|
||||
Quota: 0,
|
||||
End: true,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
var instance *utils.Buffer
|
||||
hash, err := CreateGenerationWithCache(form.Model, form.Prompt, reversible, func(buffer *utils.Buffer, data string) {
|
||||
instance = buffer
|
||||
conn.Send(globals.GenerationSegmentResponse{
|
||||
End: false,
|
||||
Message: data,
|
||||
Quota: buffer.GetQuota(),
|
||||
})
|
||||
})
|
||||
|
||||
if instance != nil && !reversible && instance.GetQuota() > 0 && user != nil {
|
||||
user.UseQuota(db, instance.GetQuota())
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
conn.Send(globals.GenerationSegmentResponse{
|
||||
End: true,
|
||||
Error: err.Error(),
|
||||
Quota: instance.GetQuota(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
conn.Send(globals.GenerationSegmentResponse{
|
||||
End: true,
|
||||
Hash: hash,
|
||||
Quota: instance.GetQuota(),
|
||||
})
|
||||
}
|
@ -1,12 +1,11 @@
|
||||
package generation
|
||||
|
||||
import (
|
||||
"chat/api"
|
||||
"chat/utils"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
func CreateGenerationWithCache(model string, prompt string, enableReverse bool, hook func(buffer *api.Buffer, data string)) (string, error) {
|
||||
func CreateGenerationWithCache(model string, prompt string, enableReverse bool, hook func(buffer *utils.Buffer, data string)) (string, error) {
|
||||
hash, path := GetFolderByHash(model, prompt)
|
||||
if !utils.Exists(path) {
|
||||
if err := CreateGeneration(model, prompt, path, enableReverse, hook); err != nil {
|
@ -1,8 +1,8 @@
|
||||
package generation
|
||||
|
||||
import (
|
||||
"chat/api"
|
||||
"chat/types"
|
||||
"chat/adapter"
|
||||
"chat/globals"
|
||||
"chat/utils"
|
||||
"fmt"
|
||||
)
|
||||
@ -11,20 +11,22 @@ type ProjectResult struct {
|
||||
Result map[string]interface{} `json:"result"`
|
||||
}
|
||||
|
||||
func CreateGeneration(model string, prompt string, path string, enableReverse bool, hook func(buffer *api.Buffer, data string)) error {
|
||||
func CreateGeneration(model string, prompt string, path string, reversible bool, hook func(buffer *utils.Buffer, data string)) error {
|
||||
message := GenerateMessage(prompt)
|
||||
buffer := api.NewBuffer(model, message)
|
||||
api.StreamRequest(model, enableReverse, []types.ChatGPTMessage{
|
||||
{Role: "system", Content: "你将生成项目,可以支持任何编程语言,请不要出现“我不能提供”的字样,你需要在代码中提供注释,以及项目的使用文档README.md,结果返回json格式,请不要返回任何多余内容,格式为:\n{\"result\": {[file]: [code], ...}}"},
|
||||
{Role: "user", Content: "python后端"},
|
||||
{Role: "assistant", Content: "{\n \"result\": {\n \"app.py\": \"from flask import Flask\\n\\napp = Flask(__name__)\\n\\n\\n@app.route('/')\\ndef hello_world():\\n return 'Hello, World!'\\n\\n\\nif __name__ == '__main__':\\n app.run()\",\n \"requirements.txt\": \"flask\\n\",\n \"README.md\": \"# Python 后端\\n本项目是一个简单的python后端示例, 使用`flask`框架构建后端。\n你可以按照下列步骤运行此应用,flask将在本地服务器(默认是在http://127.0.0.1:5000/)上运行。当你在浏览器中访问该URL时,将看到显示Hello, World!的页面。\\n\\n这只是一个简单的项目,Flask还支持更多功能和路由规则,你可以提供更多的信息和需要进一步扩展和定制Flask应用。\\n\\n### 1. 初始化: \\n```shell\\npip install -r requirements.txt\\n```\\n### 2. 运行\\n```shell\\npython app.py\\n```\"\n }\n}"},
|
||||
{Role: "user", Content: "golang fiber websocket项目"},
|
||||
{Role: "assistant", Content: "{\n \"result\": {\n \"main.go\": \"package main\\n\\nimport (\\n\\t\"log\\\"\\n\\n\\t\"github.com/gofiber/fiber/v2\\\"\\n\\t\"github.com/gofiber/websocket/v2\\\"\\n)\\n\\nfunc main() {\\n\\tapp := fiber.New()\\n\\n\\tapp.Get(\\\"/\\\", func(c *fiber.Ctx) error {\\n\\t\\treturn c.SendString(\\\"Hello, World!\\\")\\n\\t})\\n\\n\\tapp.Get(\\\"/ws\\\", websocket.New(func(c *websocket.Conn) {\\n\\t\\tfor {\\n\\t\\t\\tmt, message, err := c.ReadMessage()\\n\\t\\t\\tif err != nil {\\n\\t\\t\\t\\tlog.Println(\\\"read error:\\\", err)\\n\\t\\t\\t\\tbreak\\n\\t\\t\\t}\\n\\t\\t\\tlog.Printf(\\\"received: %s\\\", message)\\n\\t\\t\\terr = c.WriteMessage(mt, message)\\n\\t\\t\\tif err != nil {\\n\\t\\t\\t\\tlog.Println(\\\"write error:\\\", err)\\n\\t\\t\\t\\tbreak\\n\\t\\t\\t}\\n\\t\\t}\\n\\t}))\\n\\n\\tlog.Fatal(app.Listen(\\\":3000\\\"))\\n}\",\n \"go.mod\": \"module fiber-websocket\\n\\ngo 1.16\\n\\nrequire (\\n\\tgithub.com/gofiber/fiber/v2 v2.12.1\\n\\tgithub.com/gofiber/websocket/v2 v2.10.2\\n)\",\n \"README.md\": \"# Golang Fiber WebSocket项目\\n\\n这个项目是一个使用 Golang 和 Fiber 框架构建的 WebSocket 服务器示例。\\n\\n### 运行应用程序:\\n```shell\\ngo run main.go\\n```\\n\\n应用程序将在本地服务器(默认是在http://localhost:3000)上运行。当你在浏览器中访问`http://localhost:3000`时,将看到显示\"Hello, World!\"的页面。你还可以访问`http://localhost:3000/ws`来测试 WebSocket 连接。\n\n注意:在运行应用程序之前,请确保已经安装了Go语言开发环境。\"\n }\n}"},
|
||||
{Role: "user", Content: prompt},
|
||||
}, -1, func(data string) {
|
||||
buffer := utils.NewBuffer(model, message)
|
||||
|
||||
if err := adapter.NewChatRequest(&adapter.ChatProps{
|
||||
Model: model,
|
||||
Message: message,
|
||||
Reversible: reversible && globals.IsGPT4Model(model),
|
||||
Infinity: true,
|
||||
}, func(data string) error {
|
||||
buffer.Write(data)
|
||||
hook(buffer, data)
|
||||
})
|
||||
return nil
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
resp, err := utils.Unmarshal[ProjectResult](buffer.ReadBytes())
|
||||
if err != nil {
|
||||
@ -37,8 +39,8 @@ func CreateGeneration(model string, prompt string, path string, enableReverse bo
|
||||
return nil
|
||||
}
|
||||
|
||||
func GenerateMessage(prompt string) []types.ChatGPTMessage {
|
||||
return []types.ChatGPTMessage{
|
||||
func GenerateMessage(prompt string) []globals.Message {
|
||||
return []globals.Message{
|
||||
{Role: "system", Content: "你将生成项目,可以支持任何编程语言,请不要出现“我不能提供”的字样,你需要在代码中提供注释,以及项目的使用文档README.md,结果返回json格式,请不要返回任何多余内容,格式为:\n{\"result\": {[file]: [code], ...}}"},
|
||||
{Role: "user", Content: "python后端"},
|
||||
{Role: "assistant", Content: "{\n \"result\": {\n \"app.py\": \"from flask import Flask\\n\\napp = Flask(__name__)\\n\\n\\n@app.route('/')\\ndef hello_world():\\n return 'Hello, World!'\\n\\n\\nif __name__ == '__main__':\\n app.run()\",\n \"requirements.txt\": \"flask\\n\",\n \"README.md\": \"# Python 后端\\n本项目是一个简单的python后端示例, 使用`flask`框架构建后端。\n你可以按照下列步骤运行此应用,flask将在本地服务器(默认是在http://127.0.0.1:5000/)上运行。当你在浏览器中访问该URL时,将看到显示Hello, World!的页面。\\n\\n这只是一个简单的项目,Flask还支持更多功能和路由规则,你可以提供更多的信息和需要进一步扩展和定制Flask应用。\\n\\n### 1. 初始化: \\n```shell\\npip install -r requirements.txt\\n```\\n### 2. 运行\\n```shell\\npython app.py\\n```\"\n }\n}"},
|
17
addition/router.go
Normal file
17
addition/router.go
Normal file
@ -0,0 +1,17 @@
|
||||
package addition
|
||||
|
||||
import (
|
||||
"chat/addition/card"
|
||||
"chat/addition/generation"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func Register(app *gin.Engine) {
|
||||
{
|
||||
app.POST("/card", card.HandlerAPI)
|
||||
|
||||
app.GET("/generation/create", generation.GenerateAPI)
|
||||
app.GET("/generation/download/tar", generation.ProjectTarDownloadAPI)
|
||||
app.GET("/generation/download/zip", generation.ProjectZipDownloadAPI)
|
||||
}
|
||||
}
|
@ -1,29 +1,29 @@
|
||||
package api
|
||||
package web
|
||||
|
||||
import (
|
||||
"chat/types"
|
||||
"chat/globals"
|
||||
"chat/utils"
|
||||
"chat/web"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
func ChatWithWeb(message []types.ChatGPTMessage, long bool) (string, []types.ChatGPTMessage) {
|
||||
keyword := strings.TrimSpace(SearchWeb(message))
|
||||
type Hook func(message []globals.Message, token int) (string, error)
|
||||
|
||||
func ChatWithWeb(hook Hook, message []globals.Message, long bool) (string, []globals.Message) {
|
||||
keyword := strings.TrimSpace(GetKeywordPoint(hook, message))
|
||||
|
||||
if len(keyword) == 0 {
|
||||
return keyword, message
|
||||
}
|
||||
|
||||
data := web.SearchBing(keyword)
|
||||
data := SearchBing(keyword)
|
||||
if long {
|
||||
data = utils.GetSegmentString(data, 6000)
|
||||
} else {
|
||||
data = utils.GetSegmentString(data, 3000)
|
||||
}
|
||||
return keyword, utils.Insert(message, 0, types.ChatGPTMessage{
|
||||
return keyword, utils.Insert(message, 0, globals.Message{
|
||||
Role: "system",
|
||||
Content: fmt.Sprintf("你将扮演AI问答助手,你的知识库不是截止于2021年9月,而是可以实时联网的,你可以提供实时联网的知识库。"+
|
||||
"网站使用markdown链接包裹,多使用emoji和颜文字,给予用户精确的答复。"+
|
||||
@ -34,12 +34,6 @@ func ChatWithWeb(message []types.ChatGPTMessage, long bool) (string, []types.Cha
|
||||
})
|
||||
}
|
||||
|
||||
func GetRandomKey(apikey string) string {
|
||||
arr := strings.Split(apikey, "|")
|
||||
idx := rand.Intn(len(arr))
|
||||
return arr[idx]
|
||||
}
|
||||
|
||||
func StringCleaner(content string) string {
|
||||
for _, replacer := range []string{",", "、", ",", "。", ":", ":", ";", ";", "!", "!", "=", "?", "?", "(", ")", "(", ")", "关键字", "空", "1+1"} {
|
||||
content = strings.ReplaceAll(content, replacer, " ")
|
||||
@ -47,8 +41,8 @@ func StringCleaner(content string) string {
|
||||
return strings.TrimSpace(content)
|
||||
}
|
||||
|
||||
func SearchWeb(message []types.ChatGPTMessage) string {
|
||||
resp, _ := GetChatGPTResponse([]types.ChatGPTMessage{{
|
||||
func GetKeywordPoint(hook Hook, message []globals.Message) string {
|
||||
resp, _ := hook([]globals.Message{{
|
||||
Role: "system",
|
||||
Content: "If the user input content require ONLINE SEARCH to get the results, please output these keywords to refine the data Interval with space, remember not to answer other content, json format return, format {\"keyword\": \"...\" }",
|
||||
}, {
|
39
addition/web/utils.go
Normal file
39
addition/web/utils.go
Normal file
@ -0,0 +1,39 @@
|
||||
package web
|
||||
|
||||
import (
|
||||
"chat/adapter/chatgpt"
|
||||
"chat/globals"
|
||||
"chat/manager/conversation"
|
||||
)
|
||||
|
||||
func UsingWebSegment(instance *conversation.Conversation) (string, []globals.Message) {
|
||||
var keyword string
|
||||
var segment []globals.Message
|
||||
|
||||
if instance.IsEnableWeb() {
|
||||
keyword, segment = ChatWithWeb(func(message []globals.Message, token int) (string, error) {
|
||||
return chatgpt.NewChatInstanceFromConfig("gpt3").CreateChatRequest(&chatgpt.ChatProps{
|
||||
Model: globals.GPT3Turbo0613,
|
||||
Message: message,
|
||||
Token: token,
|
||||
})
|
||||
}, conversation.CopyMessage(instance.GetMessageSegment(12)), globals.IsLongContextModel(instance.GetModel()))
|
||||
} else {
|
||||
segment = conversation.CopyMessage(instance.GetMessageSegment(12))
|
||||
}
|
||||
return keyword, segment
|
||||
}
|
||||
|
||||
func UsingWebNativeSegment(enable bool, message []globals.Message) (string, []globals.Message) {
|
||||
if enable {
|
||||
return ChatWithWeb(func(message []globals.Message, token int) (string, error) {
|
||||
return chatgpt.NewChatInstanceFromConfig("gpt3").CreateChatRequest(&chatgpt.ChatProps{
|
||||
Model: globals.GPT3Turbo0613,
|
||||
Message: message,
|
||||
Token: token,
|
||||
})
|
||||
}, message, false)
|
||||
} else {
|
||||
return "", message
|
||||
}
|
||||
}
|
147
api/anonymous.go
147
api/anonymous.go
@ -1,147 +0,0 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"chat/types"
|
||||
"chat/utils"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/go-redis/redis/v8"
|
||||
"github.com/spf13/viper"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
type AnonymousRequestBody struct {
|
||||
Message string `json:"message" required:"true"`
|
||||
Web bool `json:"web"`
|
||||
History []types.ChatGPTMessage `json:"history"`
|
||||
}
|
||||
|
||||
type AnonymousResponseCache struct {
|
||||
Keyword string `json:"keyword"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
func GetChatGPTResponse(message []types.ChatGPTMessage, token int) (string, error) {
|
||||
res, err := utils.Post(viper.GetString("openai.anonymous_endpoint")+"/chat/completions", map[string]string{
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": "Bearer " + GetRandomKey(viper.GetString("openai.anonymous")),
|
||||
}, types.ChatGPTRequest{
|
||||
Model: "gpt-3.5-turbo",
|
||||
Messages: message,
|
||||
MaxToken: token,
|
||||
})
|
||||
if err != nil || res == nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if res.(map[string]interface{})["choices"] == nil {
|
||||
return res.(map[string]interface{})["error"].(map[string]interface{})["message"].(string), nil
|
||||
}
|
||||
data := res.(map[string]interface{})["choices"].([]interface{})[0].(map[string]interface{})["message"].(map[string]interface{})["content"]
|
||||
return data.(string), nil
|
||||
}
|
||||
|
||||
func TestKey(key string) bool {
|
||||
res, err := utils.Post(viper.GetString("openai.anonymous_endpoint")+"/chat/completions", map[string]string{
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": "Bearer " + key,
|
||||
}, types.ChatGPTRequest{
|
||||
Model: "gpt-3.5-turbo",
|
||||
Messages: []types.ChatGPTMessage{{Role: "user", Content: "hi"}},
|
||||
MaxToken: 2,
|
||||
})
|
||||
if err != nil || res == nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return res.(map[string]interface{})["choices"] != nil
|
||||
}
|
||||
|
||||
func GetAnonymousResponse(message []types.ChatGPTMessage, web bool) (string, string, error) {
|
||||
if !web {
|
||||
resp, err := GetChatGPTResponse(message, 1000)
|
||||
return "", resp, err
|
||||
}
|
||||
keyword, source := ChatWithWeb(message, false)
|
||||
resp, err := GetChatGPTResponse(source, 1000)
|
||||
return keyword, resp, err
|
||||
}
|
||||
|
||||
func GetSegmentMessage(data []types.ChatGPTMessage, length int) []types.ChatGPTMessage {
|
||||
if len(data) <= length {
|
||||
return data
|
||||
}
|
||||
return data[len(data)-length:]
|
||||
}
|
||||
|
||||
func GetAnonymousMessage(message string, history []types.ChatGPTMessage) []types.ChatGPTMessage {
|
||||
return append(
|
||||
GetSegmentMessage(history, 5),
|
||||
types.ChatGPTMessage{
|
||||
Role: "user",
|
||||
Content: strings.TrimSpace(message),
|
||||
})
|
||||
}
|
||||
|
||||
func GetAnonymousResponseWithCache(c *gin.Context, message string, web bool, history []types.ChatGPTMessage) (string, string, error) {
|
||||
segment := GetAnonymousMessage(message, history)
|
||||
hash := utils.Md5Encrypt(utils.ToJson(segment))
|
||||
cache := c.MustGet("cache").(*redis.Client)
|
||||
res, err := cache.Get(c, fmt.Sprintf(":chatgpt-%v:%s", web, hash)).Result()
|
||||
form := utils.UnmarshalJson[AnonymousResponseCache](res)
|
||||
if err != nil || len(res) == 0 || res == "{}" || form.Message == "" {
|
||||
key, res, err := GetAnonymousResponse(segment, web)
|
||||
if err != nil {
|
||||
return "", "There was something wrong...", err
|
||||
}
|
||||
|
||||
cache.Set(c, fmt.Sprintf(":chatgpt-%v:%s", web, hash), utils.ToJson(AnonymousResponseCache{
|
||||
Keyword: key,
|
||||
Message: res,
|
||||
}), time.Hour*48)
|
||||
return key, res, nil
|
||||
}
|
||||
return form.Keyword, form.Message, nil
|
||||
}
|
||||
|
||||
func AnonymousAPI(c *gin.Context) {
|
||||
var body AnonymousRequestBody
|
||||
if err := c.ShouldBindJSON(&body); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"status": false,
|
||||
"message": "",
|
||||
"keyword": "",
|
||||
"reason": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
message := strings.TrimSpace(body.Message)
|
||||
if len(message) == 0 {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"status": false,
|
||||
"message": "",
|
||||
"keyword": "",
|
||||
"reason": "message is empty",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
key, res, err := GetAnonymousResponseWithCache(c, message, body.Web, body.History)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"status": false,
|
||||
"message": res,
|
||||
"keyword": key,
|
||||
"reason": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"status": true,
|
||||
"message": res,
|
||||
"keyword": key,
|
||||
"reason": "",
|
||||
})
|
||||
}
|
214
api/chat.go
214
api/chat.go
@ -1,214 +0,0 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"chat/auth"
|
||||
"chat/conversation"
|
||||
"chat/middleware"
|
||||
"chat/types"
|
||||
"chat/utils"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/go-redis/redis/v8"
|
||||
"github.com/gorilla/websocket"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const defaultErrorMessage = "There was something wrong... Please try again later."
|
||||
const defaultQuotaMessage = "You have run out of GPT-4 usage. Please keep your nio points above **5**. (**GPT-4-32K** requires **50** nio points)"
|
||||
const defaultImageMessage = "Please provide description for the image (e.g. /image an apple)."
|
||||
const maxThread = 5
|
||||
|
||||
type WebsocketAuthForm struct {
|
||||
Token string `json:"token" binding:"required"`
|
||||
Id int64 `json:"id" binding:"required"`
|
||||
}
|
||||
|
||||
func SendSegmentMessage(conn *websocket.Conn, message interface{}) {
|
||||
_ = conn.WriteMessage(websocket.TextMessage, []byte(utils.ToJson(message)))
|
||||
}
|
||||
|
||||
func GetErrorQuota(model string) float32 {
|
||||
if types.IsGPT4Model(model) {
|
||||
return -0xe // special value for error
|
||||
} else {
|
||||
return 0
|
||||
}
|
||||
}
|
||||
|
||||
func GetTextSegment(instance *conversation.Conversation) (string, []types.ChatGPTMessage) {
|
||||
var keyword string
|
||||
var segment []types.ChatGPTMessage
|
||||
|
||||
if instance.IsEnableWeb() {
|
||||
keyword, segment = ChatWithWeb(conversation.CopyMessage(instance.GetMessageSegment(12)), true)
|
||||
} else {
|
||||
segment = conversation.CopyMessage(instance.GetMessageSegment(12))
|
||||
}
|
||||
return keyword, segment
|
||||
}
|
||||
|
||||
func TextChat(db *sql.DB, cache *redis.Client, user *auth.User, conn *websocket.Conn, instance *conversation.Conversation) string {
|
||||
keyword, segment := GetTextSegment(instance)
|
||||
SendSegmentMessage(conn, types.ChatSegmentResponse{Keyword: keyword, End: false})
|
||||
|
||||
model := instance.GetModel()
|
||||
useReverse := auth.CanEnableSubscription(db, cache, user)
|
||||
if !auth.CanEnableModelWithSubscription(db, user, model, useReverse) {
|
||||
SendSegmentMessage(conn, types.ChatSegmentResponse{
|
||||
Message: defaultQuotaMessage,
|
||||
Quota: 0,
|
||||
End: true,
|
||||
})
|
||||
return defaultQuotaMessage
|
||||
}
|
||||
|
||||
buffer := NewBuffer(model, segment)
|
||||
StreamRequest(model, useReverse, segment,
|
||||
utils.Multi(types.IsGPT4Model(model) || useReverse, -1, 2000),
|
||||
func(resp string) {
|
||||
SendSegmentMessage(conn, types.ChatSegmentResponse{
|
||||
Message: buffer.Write(resp),
|
||||
Quota: buffer.GetQuota(),
|
||||
End: false,
|
||||
})
|
||||
})
|
||||
if buffer.IsEmpty() {
|
||||
if useReverse {
|
||||
auth.DecreaseSubscriptionUsage(cache, user)
|
||||
}
|
||||
SendSegmentMessage(conn, types.ChatSegmentResponse{
|
||||
Message: defaultErrorMessage,
|
||||
Quota: GetErrorQuota(model),
|
||||
End: true,
|
||||
})
|
||||
return defaultErrorMessage
|
||||
}
|
||||
|
||||
// collect quota
|
||||
if !useReverse {
|
||||
user.UseQuota(db, buffer.GetQuota())
|
||||
}
|
||||
SendSegmentMessage(conn, types.ChatSegmentResponse{End: true, Quota: buffer.GetQuota()})
|
||||
|
||||
return buffer.ReadWithDefault(defaultErrorMessage)
|
||||
}
|
||||
|
||||
func ImageChat(conn *websocket.Conn, instance *conversation.Conversation, user *auth.User, db *sql.DB, cache *redis.Client) string {
|
||||
// format: /image a cat
|
||||
data := strings.TrimSpace(instance.GetLatestMessage()[6:])
|
||||
if len(data) == 0 {
|
||||
SendSegmentMessage(conn, types.ChatSegmentResponse{
|
||||
Message: defaultImageMessage,
|
||||
End: true,
|
||||
})
|
||||
return defaultImageMessage
|
||||
}
|
||||
|
||||
SendSegmentMessage(conn, types.ChatSegmentResponse{
|
||||
Message: "Generating image...\n",
|
||||
End: false,
|
||||
})
|
||||
url, err := GetImageWithUserLimit(user, data, db, cache)
|
||||
if err != nil {
|
||||
SendSegmentMessage(conn, types.ChatSegmentResponse{
|
||||
Message: err.Error(),
|
||||
End: true,
|
||||
})
|
||||
return err.Error()
|
||||
}
|
||||
|
||||
markdown := GetImageMarkdown(url)
|
||||
SendSegmentMessage(conn, types.ChatSegmentResponse{
|
||||
Quota: 1.,
|
||||
Message: markdown,
|
||||
End: true,
|
||||
})
|
||||
return markdown
|
||||
}
|
||||
|
||||
func ChatHandler(conn *websocket.Conn, instance *conversation.Conversation, user *auth.User, db *sql.DB, cache *redis.Client) string {
|
||||
if strings.HasPrefix(instance.GetLatestMessage(), "/image") {
|
||||
return ImageChat(conn, instance, user, db, cache)
|
||||
} else {
|
||||
return TextChat(db, cache, user, conn, instance)
|
||||
}
|
||||
}
|
||||
|
||||
func ChatAPI(c *gin.Context) {
|
||||
// websocket connection
|
||||
upgrader := websocket.Upgrader{
|
||||
CheckOrigin: func(r *http.Request) bool {
|
||||
origin := c.Request.Header.Get("Origin")
|
||||
if utils.Contains(origin, middleware.AllowedOrigins) {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
},
|
||||
}
|
||||
conn, err := upgrader.Upgrade(c.Writer, c.Request, nil)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"status": false,
|
||||
"message": "",
|
||||
"reason": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
defer func(conn *websocket.Conn) {
|
||||
err := conn.Close()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}(conn)
|
||||
|
||||
_, message, err := conn.ReadMessage()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
form, err := utils.Unmarshal[WebsocketAuthForm](message)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
user := auth.ParseToken(c, form.Token)
|
||||
if user == nil {
|
||||
return
|
||||
}
|
||||
|
||||
db := c.MustGet("db").(*sql.DB)
|
||||
cache := c.MustGet("cache").(*redis.Client)
|
||||
var instance *conversation.Conversation
|
||||
if form.Id == -1 {
|
||||
// create new conversation
|
||||
instance = conversation.NewConversation(db, user.GetID(db))
|
||||
} else {
|
||||
// load conversation
|
||||
instance = conversation.LoadConversation(db, user.GetID(db), form.Id)
|
||||
if instance == nil {
|
||||
instance = conversation.NewConversation(db, user.GetID(db))
|
||||
}
|
||||
}
|
||||
|
||||
id := user.GetID(db)
|
||||
|
||||
for {
|
||||
_, message, err = conn.ReadMessage()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if instance.HandleMessage(db, message) {
|
||||
if !utils.IncrWithLimit(cache, fmt.Sprintf(":chatthread:%d", id), 1, maxThread, 60) {
|
||||
SendSegmentMessage(conn, types.ChatSegmentResponse{
|
||||
Message: fmt.Sprintf("You have reached the maximum number of threads (%d) the same time. Please wait for a while.", maxThread),
|
||||
End: true,
|
||||
})
|
||||
return
|
||||
}
|
||||
msg := ChatHandler(conn, instance, user, db, cache)
|
||||
utils.DecrInt(cache, fmt.Sprintf(":chatthread:%d", id), 1)
|
||||
instance.SaveResponse(db, msg)
|
||||
}
|
||||
}
|
||||
}
|
71
api/image.go
71
api/image.go
@ -1,71 +0,0 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"chat/auth"
|
||||
"chat/types"
|
||||
"chat/utils"
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"github.com/go-redis/redis/v8"
|
||||
"github.com/spf13/viper"
|
||||
"time"
|
||||
)
|
||||
|
||||
func GetImage(prompt string) (string, error) {
|
||||
res, err := utils.Post(viper.GetString("openai.image_endpoint")+"/images/generations", map[string]string{
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": "Bearer " + GetRandomKey(viper.GetString("openai.image")),
|
||||
}, types.ChatGPTImageRequest{
|
||||
Prompt: prompt,
|
||||
Size: "512x512",
|
||||
N: 1,
|
||||
})
|
||||
if err != nil || res == nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if err, ok := res.(map[string]interface{})["error"]; ok {
|
||||
return "", fmt.Errorf(err.(map[string]interface{})["message"].(string))
|
||||
}
|
||||
data := res.(map[string]interface{})["data"].([]interface{})[0].(map[string]interface{})["url"]
|
||||
return data.(string), nil
|
||||
}
|
||||
|
||||
func GetImageWithCache(ctx context.Context, prompt string, cache *redis.Client) (string, error) {
|
||||
res, err := cache.Get(ctx, fmt.Sprintf(":image:%s", prompt)).Result()
|
||||
if err != nil || len(res) == 0 || res == "" {
|
||||
res, err := GetImage(prompt)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
cache.Set(ctx, fmt.Sprintf(":image:%s", prompt), res, time.Hour*6)
|
||||
return res, nil
|
||||
}
|
||||
|
||||
return res, nil
|
||||
}
|
||||
|
||||
func GetLimitFormat(id int64) string {
|
||||
today := time.Now().Format("2006-01-02")
|
||||
return fmt.Sprintf(":imagelimit:%s:%d", today, id)
|
||||
}
|
||||
|
||||
func GetImageWithUserLimit(user *auth.User, prompt string, db *sql.DB, cache *redis.Client) (string, error) {
|
||||
// free plan: 5 images per day
|
||||
// pro plan: 50 images per day
|
||||
|
||||
key := GetLimitFormat(user.GetID(db))
|
||||
usage := auth.GetDalleUsageLimit(db, user)
|
||||
|
||||
if utils.IncrWithLimit(cache, key, 1, int64(usage), 60*60*24) || auth.ReduceDalle(db, user) {
|
||||
return GetImageWithCache(context.Background(), prompt, cache)
|
||||
} else {
|
||||
return "", fmt.Errorf("you have reached your limit of %d free images per day, please buy more quota or wait until tomorrow", usage)
|
||||
}
|
||||
}
|
||||
|
||||
func GetImageMarkdown(url string) string {
|
||||
return fmt.Sprintln("")
|
||||
}
|
@ -1,24 +0,0 @@
|
||||
package api
|
||||
|
||||
import "strings"
|
||||
|
||||
func FilterKeys(keys string) string {
|
||||
stack := make(chan string, len(strings.Split(keys, "|")))
|
||||
for _, key := range strings.Split(keys, "|") {
|
||||
go func(key string) {
|
||||
if TestKey(key) {
|
||||
stack <- key
|
||||
} else {
|
||||
stack <- ""
|
||||
}
|
||||
}(key)
|
||||
}
|
||||
|
||||
var result string
|
||||
for i := 0; i < len(strings.Split(keys, "|")); i++ {
|
||||
if res := <-stack; res != "" {
|
||||
result += res + "|"
|
||||
}
|
||||
}
|
||||
return strings.Trim(result, "|")
|
||||
}
|
133
api/stream.go
133
api/stream.go
@ -1,133 +0,0 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"chat/types"
|
||||
"chat/utils"
|
||||
"crypto/tls"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/spf13/viper"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func processLine(buf []byte) []string {
|
||||
data := strings.Trim(string(buf), "\n")
|
||||
rep := strings.NewReplacer(
|
||||
"data: {",
|
||||
"\"data\": {",
|
||||
)
|
||||
data = rep.Replace(data)
|
||||
array := strings.Split(data, "\n")
|
||||
resp := make([]string, 0)
|
||||
for _, item := range array {
|
||||
item = strings.TrimSpace(item)
|
||||
if !strings.HasPrefix(item, "{") {
|
||||
item = "{" + item
|
||||
}
|
||||
if !strings.HasSuffix(item, "}}") {
|
||||
item = item + "}"
|
||||
}
|
||||
|
||||
if item == "{data: [DONE]}" || item == "{data: [DONE]}}" || item == "{[DONE]}" {
|
||||
break
|
||||
} else if item == "{data:}" || item == "{data:}}" {
|
||||
continue
|
||||
}
|
||||
|
||||
var form types.ChatGPTStreamResponse
|
||||
if err := json.Unmarshal([]byte(item), &form); err != nil {
|
||||
if err := json.Unmarshal([]byte(item[:len(item)-1]), &form); err != nil {
|
||||
log.Println(item, err)
|
||||
}
|
||||
}
|
||||
choices := form.Data.Choices
|
||||
if len(choices) > 0 {
|
||||
resp = append(resp, choices[0].Delta.Content)
|
||||
}
|
||||
}
|
||||
return resp
|
||||
}
|
||||
|
||||
func MixRequestBody(model string, messages []types.ChatGPTMessage, token int) interface{} {
|
||||
if token == -1 {
|
||||
return types.ChatGPTRequestWithInfinity{
|
||||
Model: model,
|
||||
Messages: messages,
|
||||
Stream: true,
|
||||
}
|
||||
}
|
||||
|
||||
return types.ChatGPTRequest{
|
||||
Model: model,
|
||||
Messages: messages,
|
||||
MaxToken: token,
|
||||
Stream: true,
|
||||
}
|
||||
}
|
||||
|
||||
func NativeStreamRequest(model string, endpoint string, apikeys string, messages []types.ChatGPTMessage, token int, callback func(string)) {
|
||||
http.DefaultTransport.(*http.Transport).TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
|
||||
|
||||
client := &http.Client{}
|
||||
req, err := http.NewRequest("POST", endpoint+"/chat/completions", utils.ConvertBody(MixRequestBody(model, messages, token)))
|
||||
if err != nil {
|
||||
fmt.Println(err)
|
||||
return
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+GetRandomKey(apikeys))
|
||||
|
||||
res, err := client.Do(req)
|
||||
if err != nil {
|
||||
fmt.Println(fmt.Sprintf("[stream] error: %s (status: %d)", err.Error(), res.StatusCode))
|
||||
return
|
||||
} else if res.StatusCode >= 400 || res.StatusCode < 200 || res == nil {
|
||||
fmt.Println(fmt.Sprintf("[stream] request failed (status: %d)", res.StatusCode))
|
||||
return
|
||||
}
|
||||
defer res.Body.Close()
|
||||
for {
|
||||
buf := make([]byte, 20480)
|
||||
n, err := res.Body.Read(buf)
|
||||
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
}
|
||||
|
||||
for _, item := range processLine(buf[:n]) {
|
||||
callback(item)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func StreamRequest(model string, enableReverse bool, messages []types.ChatGPTMessage, token int, callback func(string)) {
|
||||
switch model {
|
||||
case types.GPT4,
|
||||
types.GPT40314,
|
||||
types.GPT40613:
|
||||
if enableReverse {
|
||||
NativeStreamRequest(viper.GetString("openai.reverse"), viper.GetString("openai.pro_endpoint"), viper.GetString("openai.pro"), messages, token, callback)
|
||||
} else {
|
||||
NativeStreamRequest(types.GPT40613, viper.GetString("openai.gpt4_endpoint"), viper.GetString("openai.gpt4"), messages, token, callback)
|
||||
}
|
||||
case types.GPT432k,
|
||||
types.GPT432k0613,
|
||||
types.GPT432k0314:
|
||||
NativeStreamRequest(types.GPT432k0613, viper.GetString("openai.gpt4_endpoint"), viper.GetString("openai.gpt4"), messages, token, callback)
|
||||
case types.GPT3Turbo16k,
|
||||
types.GPT3Turbo16k0301,
|
||||
types.GPT3Turbo16k0613:
|
||||
NativeStreamRequest(types.GPT3Turbo16k0613, viper.GetString("openai.user_endpoint"), viper.GetString("openai.user"), messages, token, callback)
|
||||
case types.Claude2,
|
||||
types.Claude2100k:
|
||||
NativeStreamRequest(model, viper.GetString("claude.endpoint"), viper.GetString("claude.key"), messages, token, callback)
|
||||
default:
|
||||
NativeStreamRequest(types.GPT3Turbo0613, viper.GetString("openai.anonymous_endpoint"), viper.GetString("openai.anonymous"), messages, token, callback)
|
||||
}
|
||||
}
|
13
auth/router.go
Normal file
13
auth/router.go
Normal file
@ -0,0 +1,13 @@
|
||||
package auth
|
||||
|
||||
import "github.com/gin-gonic/gin"
|
||||
|
||||
func Register(app *gin.Engine) {
|
||||
app.POST("/login", LoginAPI)
|
||||
app.POST("/state", StateAPI)
|
||||
app.GET("/package", PackageAPI)
|
||||
app.GET("/quota", QuotaAPI)
|
||||
app.POST("/buy", BuyAPI)
|
||||
app.GET("/subscription", SubscriptionAPI)
|
||||
app.POST("/subscribe", SubscribeAPI)
|
||||
}
|
@ -38,6 +38,9 @@ func DecreaseSubscriptionUsage(cache *redis.Client, user *User) bool {
|
||||
}
|
||||
|
||||
func CanEnableSubscription(db *sql.DB, cache *redis.Client, user *User) bool {
|
||||
if user == nil {
|
||||
return false
|
||||
}
|
||||
return user.IsSubscribe(db) && IncreaseSubscriptionUsage(cache, user)
|
||||
}
|
||||
|
||||
|
@ -1,7 +1,7 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"chat/types"
|
||||
"chat/globals"
|
||||
"chat/utils"
|
||||
"database/sql"
|
||||
)
|
||||
@ -27,17 +27,17 @@ import (
|
||||
// ¥0.13 / per image
|
||||
// 1 nio / per image
|
||||
|
||||
func CountInputToken(model string, v []types.ChatGPTMessage) float32 {
|
||||
func CountInputToken(model string, v []globals.Message) float32 {
|
||||
switch model {
|
||||
case types.GPT3Turbo:
|
||||
case globals.GPT3Turbo:
|
||||
return 0
|
||||
case types.GPT3Turbo16k:
|
||||
case globals.GPT3Turbo16k:
|
||||
return 0
|
||||
case types.GPT4:
|
||||
case globals.GPT4:
|
||||
return float32(utils.CountTokenPrice(v, model)) / 1000 * 2.1
|
||||
case types.GPT432k:
|
||||
case globals.GPT432k:
|
||||
return float32(utils.CountTokenPrice(v, model)) / 1000 * 4.2
|
||||
case types.Claude2, types.Claude2100k:
|
||||
case globals.Claude2, globals.Claude2100k:
|
||||
return 0
|
||||
default:
|
||||
return 0
|
||||
@ -46,15 +46,15 @@ func CountInputToken(model string, v []types.ChatGPTMessage) float32 {
|
||||
|
||||
func CountOutputToken(model string, t int) float32 {
|
||||
switch model {
|
||||
case types.GPT3Turbo:
|
||||
case globals.GPT3Turbo:
|
||||
return 0
|
||||
case types.GPT3Turbo16k:
|
||||
case globals.GPT3Turbo16k:
|
||||
return 0
|
||||
case types.GPT4:
|
||||
case globals.GPT4:
|
||||
return float32(t*utils.GetWeightByModel(model)) / 1000 * 4.3
|
||||
case types.GPT432k:
|
||||
case globals.GPT432k:
|
||||
return float32(t*utils.GetWeightByModel(model)) / 1000 * 8.6
|
||||
case types.Claude2, types.Claude2100k:
|
||||
case globals.Claude2, globals.Claude2100k:
|
||||
return 0
|
||||
default:
|
||||
return 0
|
||||
@ -70,17 +70,17 @@ func ReduceDalle(db *sql.DB, user *User) bool {
|
||||
|
||||
func CanEnableModel(db *sql.DB, user *User, model string) bool {
|
||||
switch model {
|
||||
case types.GPT4, types.GPT40613, types.GPT40314:
|
||||
return user.GetQuota(db) >= 5
|
||||
case types.GPT432k, types.GPT432k0613, types.GPT432k0314:
|
||||
return user.GetQuota(db) >= 50
|
||||
case globals.GPT4, globals.GPT40613, globals.GPT40314:
|
||||
return user != nil && user.GetQuota(db) >= 5
|
||||
case globals.GPT432k, globals.GPT432k0613, globals.GPT432k0314:
|
||||
return user != nil && user.GetQuota(db) >= 50
|
||||
default:
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
func CanEnableModelWithSubscription(db *sql.DB, user *User, model string, useReverse bool) bool {
|
||||
if utils.Contains(model, types.GPT4Array) {
|
||||
if utils.Contains(model, globals.GPT4Array) {
|
||||
if useReverse {
|
||||
return true
|
||||
}
|
||||
|
16
auth/user.go
16
auth/user.go
@ -25,6 +25,22 @@ type LoginForm struct {
|
||||
Token string `form:"token" binding:"required"`
|
||||
}
|
||||
|
||||
func GetUser(c *gin.Context) *User {
|
||||
if c.GetBool("auth") {
|
||||
return &User{
|
||||
Username: c.GetString("user"),
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func GetId(db *sql.DB, user *User) int64 {
|
||||
if user == nil {
|
||||
return -1
|
||||
}
|
||||
return user.GetID(db)
|
||||
}
|
||||
|
||||
func (u *User) Validate(c *gin.Context) bool {
|
||||
if u.Username == "" || u.Password == "" {
|
||||
return false
|
||||
|
@ -1,126 +0,0 @@
|
||||
package generation
|
||||
|
||||
import (
|
||||
"chat/api"
|
||||
"chat/auth"
|
||||
"chat/middleware"
|
||||
"chat/types"
|
||||
"chat/utils"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/go-redis/redis/v8"
|
||||
"github.com/gorilla/websocket"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type WebsocketGenerationForm struct {
|
||||
Token string `json:"token" binding:"required"`
|
||||
Prompt string `json:"prompt" binding:"required"`
|
||||
Model string `json:"model" binding:"required"`
|
||||
}
|
||||
|
||||
func ProjectTarDownloadAPI(c *gin.Context) {
|
||||
hash := strings.TrimSpace(c.Query("hash"))
|
||||
c.Writer.Header().Add("Content-Disposition", "attachment; filename=code.tar.gz")
|
||||
c.File(fmt.Sprintf("generation/data/out/%s.tar.gz", hash))
|
||||
}
|
||||
|
||||
func ProjectZipDownloadAPI(c *gin.Context) {
|
||||
hash := strings.TrimSpace(c.Query("hash"))
|
||||
c.Writer.Header().Add("Content-Disposition", "attachment; filename=code.zip")
|
||||
c.File(fmt.Sprintf("generation/data/out/%s.zip", hash))
|
||||
}
|
||||
|
||||
func GenerateAPI(c *gin.Context) {
|
||||
upgrader := websocket.Upgrader{
|
||||
CheckOrigin: func(r *http.Request) bool {
|
||||
origin := c.Request.Header.Get("Origin")
|
||||
if utils.Contains(origin, middleware.AllowedOrigins) {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
},
|
||||
}
|
||||
conn, err := upgrader.Upgrade(c.Writer, c.Request, nil)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"status": false,
|
||||
"message": "",
|
||||
"reason": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
defer func(conn *websocket.Conn) {
|
||||
err := conn.Close()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}(conn)
|
||||
|
||||
_, message, err := conn.ReadMessage()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
form, err := utils.Unmarshal[WebsocketGenerationForm](message)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
user := auth.ParseToken(c, form.Token)
|
||||
if user == nil {
|
||||
return
|
||||
}
|
||||
|
||||
db := c.MustGet("db").(*sql.DB)
|
||||
cache := c.MustGet("cache").(*redis.Client)
|
||||
|
||||
id := user.GetID(db)
|
||||
if !utils.IncrWithLimit(cache, fmt.Sprintf(":generation:%d", id), 1, 30, 3600) {
|
||||
api.SendSegmentMessage(conn, types.GenerationSegmentResponse{
|
||||
End: true,
|
||||
Error: "generation rate limit exceeded, the max generation rate is 30 per hour.",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
useReverse := auth.CanEnableSubscription(db, cache, user)
|
||||
if !auth.CanEnableModelWithSubscription(db, user, form.Model, useReverse) {
|
||||
api.SendSegmentMessage(conn, types.ChatSegmentResponse{
|
||||
Message: "You don't have enough quota to use this model.",
|
||||
Quota: 0,
|
||||
End: true,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
var instance *api.Buffer
|
||||
hash, err := CreateGenerationWithCache(form.Model, form.Prompt, useReverse, func(buffer *api.Buffer, data string) {
|
||||
instance = buffer
|
||||
api.SendSegmentMessage(conn, types.GenerationSegmentResponse{
|
||||
End: false,
|
||||
Message: data,
|
||||
Quota: buffer.GetQuota(),
|
||||
})
|
||||
})
|
||||
|
||||
if instance != nil && !useReverse {
|
||||
user.UseQuota(db, instance.GetQuota())
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
api.SendSegmentMessage(conn, types.GenerationSegmentResponse{
|
||||
End: true,
|
||||
Error: err.Error(),
|
||||
Quota: instance.GetQuota(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
api.SendSegmentMessage(conn, types.GenerationSegmentResponse{
|
||||
End: true,
|
||||
Hash: hash,
|
||||
Quota: instance.GetQuota(),
|
||||
})
|
||||
}
|
21
globals/types.go
Normal file
21
globals/types.go
Normal file
@ -0,0 +1,21 @@
|
||||
package globals
|
||||
|
||||
type Message struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
type ChatSegmentResponse struct {
|
||||
Quota float32 `json:"quota"`
|
||||
Keyword string `json:"keyword"`
|
||||
Message string `json:"message"`
|
||||
End bool `json:"end"`
|
||||
}
|
||||
|
||||
type GenerationSegmentResponse struct {
|
||||
Quota float32 `json:"quota"`
|
||||
Message string `json:"message"`
|
||||
Hash string `json:"hash"`
|
||||
End bool `json:"end"`
|
||||
Error string `json:"error"`
|
||||
}
|
@ -1,4 +1,16 @@
|
||||
package types
|
||||
package globals
|
||||
|
||||
const ChatMaxThread = 5
|
||||
const AnonymousMaxThread = 1
|
||||
|
||||
var AllowedOrigins = []string{
|
||||
"https://fystart.cn",
|
||||
"https://www.fystart.cn",
|
||||
"https://nio.fystart.cn",
|
||||
"https://chatnio.net",
|
||||
"https://www.chatnio.net",
|
||||
"http://localhost:5173",
|
||||
}
|
||||
|
||||
const (
|
||||
GPT3Turbo = "gpt-3.5-turbo"
|
||||
@ -47,6 +59,17 @@ var ClaudeModelArray = []string{
|
||||
Claude2100k,
|
||||
}
|
||||
|
||||
var LongContextModelArray = []string{
|
||||
GPT3Turbo16k,
|
||||
GPT3Turbo16k0613,
|
||||
GPT3Turbo16k0301,
|
||||
GPT432k,
|
||||
GPT432k0314,
|
||||
GPT432k0613,
|
||||
Claude2,
|
||||
Claude2100k,
|
||||
}
|
||||
|
||||
func in(value string, slice []string) bool {
|
||||
for _, item := range slice {
|
||||
if item == value {
|
||||
@ -60,10 +83,22 @@ func IsGPT4Model(model string) bool {
|
||||
return in(model, GPT4Array) || in(model, GPT432kArray)
|
||||
}
|
||||
|
||||
func IsGPT4NativeModel(model string) bool {
|
||||
return in(model, GPT4Array)
|
||||
}
|
||||
|
||||
func IsGPT3TurboModel(model string) bool {
|
||||
return in(model, GPT3TurboArray) || in(model, GPT3Turbo16kArray)
|
||||
}
|
||||
|
||||
func IsChatGPTModel(model string) bool {
|
||||
return IsGPT3TurboModel(model) || IsGPT4Model(model) || model == Dalle
|
||||
}
|
||||
|
||||
func IsClaudeModel(model string) bool {
|
||||
return in(model, ClaudeModelArray)
|
||||
}
|
||||
|
||||
func IsLongContextModel(model string) bool {
|
||||
return in(model, LongContextModelArray)
|
||||
}
|
10
go.mod
10
go.mod
@ -7,7 +7,12 @@ require (
|
||||
github.com/gin-gonic/gin v1.9.1
|
||||
github.com/go-redis/redis/v8 v8.11.5
|
||||
github.com/go-sql-driver/mysql v1.7.1
|
||||
github.com/google/uuid v1.3.1
|
||||
github.com/gorilla/websocket v1.5.0
|
||||
github.com/pkoukk/tiktoken-go v0.1.5
|
||||
github.com/russross/blackfriday/v2 v2.1.0
|
||||
github.com/spf13/viper v1.16.0
|
||||
golang.org/x/net v0.10.0
|
||||
)
|
||||
|
||||
require (
|
||||
@ -23,8 +28,6 @@ require (
|
||||
github.com/go-playground/universal-translator v0.18.1 // indirect
|
||||
github.com/go-playground/validator/v10 v10.14.0 // indirect
|
||||
github.com/goccy/go-json v0.10.2 // indirect
|
||||
github.com/google/uuid v1.3.1 // indirect
|
||||
github.com/gorilla/websocket v1.5.0 // indirect
|
||||
github.com/hashicorp/hcl v1.0.0 // indirect
|
||||
github.com/json-iterator/go v1.1.12 // indirect
|
||||
github.com/klauspost/cpuid/v2 v2.2.4 // indirect
|
||||
@ -35,8 +38,6 @@ require (
|
||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
|
||||
github.com/modern-go/reflect2 v1.0.2 // indirect
|
||||
github.com/pelletier/go-toml/v2 v2.0.8 // indirect
|
||||
github.com/pkoukk/tiktoken-go v0.1.5 // indirect
|
||||
github.com/russross/blackfriday/v2 v2.1.0 // indirect
|
||||
github.com/spf13/afero v1.9.5 // indirect
|
||||
github.com/spf13/cast v1.5.1 // indirect
|
||||
github.com/spf13/jwalterweatherman v1.1.0 // indirect
|
||||
@ -46,7 +47,6 @@ require (
|
||||
github.com/ugorji/go/codec v1.2.11 // indirect
|
||||
golang.org/x/arch v0.3.0 // indirect
|
||||
golang.org/x/crypto v0.9.0 // indirect
|
||||
golang.org/x/net v0.10.0 // indirect
|
||||
golang.org/x/sys v0.8.0 // indirect
|
||||
golang.org/x/text v0.9.0 // indirect
|
||||
google.golang.org/protobuf v1.30.0 // indirect
|
||||
|
2
go.sum
2
go.sum
@ -149,8 +149,6 @@ github.com/google/pprof v0.0.0-20201203190320-1bf35d6f28c2/go.mod h1:kpwsk12EmLe
|
||||
github.com/google/pprof v0.0.0-20201218002935-b9804c9f04c2/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE=
|
||||
github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI=
|
||||
github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
|
||||
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/google/uuid v1.3.1 h1:KjJaJ9iWZ3jOFZIf1Lqf4laDRCasjl0BCmnEGxkdLb4=
|
||||
github.com/google/uuid v1.3.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/googleapis/gax-go/v2 v2.0.4/go.mod h1:0Wqv26UfaUD9n4G6kQubkQ+KchISgw+vpHVxEJEs9eg=
|
||||
|
45
main.go
45
main.go
@ -1,12 +1,13 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"chat/api"
|
||||
"chat/addition"
|
||||
"chat/auth"
|
||||
"chat/connection"
|
||||
"chat/conversation"
|
||||
"chat/generation"
|
||||
"chat/manager"
|
||||
"chat/manager/conversation"
|
||||
"chat/middleware"
|
||||
"chat/utils"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/spf13/viper"
|
||||
)
|
||||
@ -18,35 +19,17 @@ func main() {
|
||||
}
|
||||
|
||||
app := gin.Default()
|
||||
middleware.RegisterMiddleware(app)
|
||||
|
||||
{
|
||||
app.Use(middleware.CORSMiddleware())
|
||||
app.Use(middleware.BuiltinMiddleWare(connection.ConnectMySQL(), connection.ConnectRedis()))
|
||||
app.Use(middleware.ThrottleMiddleware())
|
||||
app.Use(auth.Middleware())
|
||||
auth.Register(app)
|
||||
manager.Register(app)
|
||||
addition.Register(app)
|
||||
conversation.Register(app)
|
||||
}
|
||||
|
||||
app.POST("/anonymous", api.AnonymousAPI)
|
||||
app.POST("/card", api.CardAPI)
|
||||
app.GET("/chat", api.ChatAPI)
|
||||
app.POST("/login", auth.LoginAPI)
|
||||
app.POST("/state", auth.StateAPI)
|
||||
app.GET("/package", auth.PackageAPI)
|
||||
app.GET("/quota", auth.QuotaAPI)
|
||||
app.POST("/buy", auth.BuyAPI)
|
||||
app.GET("/subscription", auth.SubscriptionAPI)
|
||||
app.POST("/subscribe", auth.SubscribeAPI)
|
||||
app.GET("/conversation/list", conversation.ListAPI)
|
||||
app.GET("/conversation/load", conversation.LoadAPI)
|
||||
app.GET("/conversation/delete", conversation.DeleteAPI)
|
||||
app.GET("/generation/create", generation.GenerateAPI)
|
||||
app.GET("/generation/download/tar", generation.ProjectTarDownloadAPI)
|
||||
app.GET("/generation/download/zip", generation.ProjectZipDownloadAPI)
|
||||
}
|
||||
if viper.GetBool("debug") {
|
||||
gin.SetMode(gin.DebugMode)
|
||||
} else {
|
||||
gin.SetMode(gin.ReleaseMode)
|
||||
}
|
||||
if err := app.Run(":" + viper.GetString("server.port")); err != nil {
|
||||
gin.SetMode(utils.Multi[string](viper.GetBool("debug"), gin.DebugMode, gin.ReleaseMode))
|
||||
if err := app.Run(fmt.Sprintf(":%s", viper.GetString("server.port"))); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
34
manager/cache.go
Normal file
34
manager/cache.go
Normal file
@ -0,0 +1,34 @@
|
||||
package manager
|
||||
|
||||
import (
|
||||
"chat/globals"
|
||||
"chat/utils"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"time"
|
||||
)
|
||||
|
||||
type CacheProps struct {
|
||||
Message []globals.Message `json:"message" required:"true"`
|
||||
Model string `json:"model" required:"true"`
|
||||
Reversible bool `json:"reversible"`
|
||||
}
|
||||
|
||||
type CacheData struct {
|
||||
Keyword string `json:"keyword"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
func ExtractCacheData(c *gin.Context, props *CacheProps) *CacheData {
|
||||
hash := utils.Md5Encrypt(utils.Marshal(props))
|
||||
data, err := utils.GetCacheFromContext(c).Get(c, fmt.Sprintf(":niodata:%s", hash)).Result()
|
||||
if err == nil && data != "" {
|
||||
return utils.UnmarshalForm[CacheData](data)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func SaveCacheData(c *gin.Context, props *CacheProps, data *CacheData) {
|
||||
hash := utils.Md5Encrypt(utils.Marshal(props))
|
||||
utils.GetCacheFromContext(c).Set(c, fmt.Sprintf(":niodata:%s", hash), utils.Marshal(data), time.Hour*12)
|
||||
}
|
103
manager/chat.go
Normal file
103
manager/chat.go
Normal file
@ -0,0 +1,103 @@
|
||||
package manager
|
||||
|
||||
import (
|
||||
"chat/adapter"
|
||||
"chat/addition/web"
|
||||
"chat/auth"
|
||||
"chat/globals"
|
||||
"chat/manager/conversation"
|
||||
"chat/utils"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
const defaultMessage = "Sorry, I don't understand. Please try again."
|
||||
const defaultQuotaMessage = "You don't have enough quota to use this model. Please buy more quota to continue." +
|
||||
"| GPT-4 | GPT-4-32k " +
|
||||
"| 5 nio | 50 nio "
|
||||
|
||||
func GetErrorQuota(model string) float32 {
|
||||
return utils.Multi[float32](globals.IsGPT4Model(model), -0xe, 0) // special value for error
|
||||
}
|
||||
|
||||
func CollectQuota(c *gin.Context, user *auth.User, quota float32, reversible bool) {
|
||||
db := utils.GetDBFromContext(c)
|
||||
if !reversible && quota > 0 && user != nil {
|
||||
user.UseQuota(db, quota)
|
||||
}
|
||||
}
|
||||
|
||||
func ChatHandler(conn *utils.WebSocket, user *auth.User, instance *conversation.Conversation) string {
|
||||
defer func() {
|
||||
if err := recover(); err != nil {
|
||||
fmt.Println(fmt.Sprintf("caught panic from chat handler: %s (instance: %s, client: %s)",
|
||||
err, instance.GetModel(), conn.GetCtx().ClientIP(),
|
||||
))
|
||||
}
|
||||
}()
|
||||
|
||||
keyword, segment := web.UsingWebSegment(instance)
|
||||
conn.Send(globals.ChatSegmentResponse{Keyword: keyword, End: false})
|
||||
|
||||
model := instance.GetModel()
|
||||
db := conn.GetDB()
|
||||
cache := conn.GetCache()
|
||||
reversible := auth.CanEnableSubscription(db, cache, user)
|
||||
|
||||
if !auth.CanEnableModelWithSubscription(db, user, model, reversible) {
|
||||
conn.Send(globals.ChatSegmentResponse{
|
||||
Message: defaultQuotaMessage,
|
||||
Quota: 0,
|
||||
End: true,
|
||||
})
|
||||
return defaultQuotaMessage
|
||||
}
|
||||
|
||||
if form := ExtractCacheData(conn.GetCtx(), &CacheProps{
|
||||
Message: segment,
|
||||
Model: model,
|
||||
Reversible: reversible,
|
||||
}); form != nil {
|
||||
conn.Send(globals.ChatSegmentResponse{
|
||||
Message: form.Message,
|
||||
Quota: 0,
|
||||
End: true,
|
||||
})
|
||||
return form.Message
|
||||
}
|
||||
|
||||
buffer := utils.NewBuffer(model, segment)
|
||||
if err := adapter.NewChatRequest(&adapter.ChatProps{
|
||||
Model: model,
|
||||
Message: segment,
|
||||
Reversible: reversible && globals.IsGPT4Model(model),
|
||||
}, func(data string) error {
|
||||
return conn.SendJSON(globals.ChatSegmentResponse{
|
||||
Message: buffer.Write(data),
|
||||
Quota: buffer.GetQuota(),
|
||||
End: false,
|
||||
})
|
||||
}); err != nil {
|
||||
CollectQuota(conn.GetCtx(), user, buffer.GetQuota(), reversible)
|
||||
conn.Send(globals.ChatSegmentResponse{
|
||||
Message: err.Error(),
|
||||
Quota: GetErrorQuota(model),
|
||||
End: true,
|
||||
})
|
||||
return err.Error()
|
||||
}
|
||||
|
||||
CollectQuota(conn.GetCtx(), user, buffer.GetQuota(), reversible)
|
||||
conn.Send(globals.ChatSegmentResponse{End: true, Quota: buffer.GetQuota()})
|
||||
|
||||
SaveCacheData(conn.GetCtx(), &CacheProps{
|
||||
Message: segment,
|
||||
Model: model,
|
||||
Reversible: reversible,
|
||||
}, &CacheData{
|
||||
Keyword: keyword,
|
||||
Message: buffer.ReadWithDefault(defaultMessage),
|
||||
})
|
||||
|
||||
return buffer.ReadWithDefault(defaultMessage)
|
||||
}
|
65
manager/completions.go
Normal file
65
manager/completions.go
Normal file
@ -0,0 +1,65 @@
|
||||
package manager
|
||||
|
||||
import (
|
||||
"chat/adapter"
|
||||
"chat/addition/web"
|
||||
"chat/auth"
|
||||
"chat/globals"
|
||||
"chat/utils"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func NativeChatHandler(c *gin.Context, user *auth.User, model string, message []globals.Message, enableWeb bool) (string, string, float32) {
|
||||
defer func() {
|
||||
if err := recover(); err != nil {
|
||||
fmt.Println(fmt.Sprintf("caught panic from chat handler: %s (instance: %s, client: %s)",
|
||||
err, model, c.ClientIP(),
|
||||
))
|
||||
}
|
||||
}()
|
||||
|
||||
keyword, segment := web.UsingWebNativeSegment(enableWeb, message)
|
||||
|
||||
db := utils.GetDBFromContext(c)
|
||||
cache := utils.GetCacheFromContext(c)
|
||||
reversible := auth.CanEnableSubscription(db, cache, user)
|
||||
|
||||
if !auth.CanEnableModelWithSubscription(db, user, model, reversible) {
|
||||
return keyword, defaultQuotaMessage, 0
|
||||
}
|
||||
|
||||
if form := ExtractCacheData(c, &CacheProps{
|
||||
Message: segment,
|
||||
Model: model,
|
||||
Reversible: reversible,
|
||||
}); form != nil {
|
||||
return form.Keyword, form.Message, 0
|
||||
}
|
||||
|
||||
buffer := utils.NewBuffer(model, segment)
|
||||
if err := adapter.NewChatRequest(&adapter.ChatProps{
|
||||
Model: model,
|
||||
Reversible: reversible && globals.IsGPT4Model(model),
|
||||
Message: segment,
|
||||
}, func(resp string) error {
|
||||
buffer.Write(resp)
|
||||
return nil
|
||||
}); err != nil {
|
||||
CollectQuota(c, user, buffer.GetQuota(), reversible)
|
||||
return keyword, err.Error(), GetErrorQuota(model)
|
||||
}
|
||||
|
||||
CollectQuota(c, user, buffer.GetQuota(), reversible)
|
||||
|
||||
SaveCacheData(c, &CacheProps{
|
||||
Message: segment,
|
||||
Model: model,
|
||||
Reversible: reversible,
|
||||
}, &CacheData{
|
||||
Keyword: keyword,
|
||||
Message: buffer.ReadWithDefault(defaultMessage),
|
||||
})
|
||||
|
||||
return keyword, buffer.ReadWithDefault(defaultMessage), buffer.GetQuota()
|
||||
}
|
@ -1,7 +1,8 @@
|
||||
package conversation
|
||||
|
||||
import (
|
||||
"chat/types"
|
||||
"chat/auth"
|
||||
"chat/globals"
|
||||
"chat/utils"
|
||||
"database/sql"
|
||||
"errors"
|
||||
@ -9,12 +10,13 @@ import (
|
||||
)
|
||||
|
||||
type Conversation struct {
|
||||
UserID int64 `json:"user_id"`
|
||||
Id int64 `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Message []types.ChatGPTMessage `json:"message"`
|
||||
Model string `json:"model"`
|
||||
EnableWeb bool `json:"enable_web"`
|
||||
Auth bool `json:"auth"`
|
||||
UserID int64 `json:"user_id"`
|
||||
Id int64 `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Message []globals.Message `json:"message"`
|
||||
Model string `json:"model"`
|
||||
EnableWeb bool `json:"enable_web"`
|
||||
}
|
||||
|
||||
type FormMessage struct {
|
||||
@ -24,17 +26,48 @@ type FormMessage struct {
|
||||
Model string `json:"model"`
|
||||
}
|
||||
|
||||
func NewAnonymousConversation() *Conversation {
|
||||
return &Conversation{
|
||||
Auth: false,
|
||||
UserID: -1,
|
||||
Id: -1,
|
||||
Name: "anonymous",
|
||||
Message: []globals.Message{},
|
||||
Model: globals.GPT3Turbo,
|
||||
EnableWeb: false,
|
||||
}
|
||||
}
|
||||
|
||||
func NewConversation(db *sql.DB, id int64) *Conversation {
|
||||
return &Conversation{
|
||||
Auth: true,
|
||||
UserID: id,
|
||||
Id: GetConversationLengthByUserID(db, id) + 1,
|
||||
Name: "new chat",
|
||||
Message: []types.ChatGPTMessage{},
|
||||
Model: types.GPT3Turbo,
|
||||
Message: []globals.Message{},
|
||||
Model: globals.GPT3Turbo,
|
||||
EnableWeb: false,
|
||||
}
|
||||
}
|
||||
|
||||
func ExtractConversation(db *sql.DB, user *auth.User, id int64) *Conversation {
|
||||
if user == nil {
|
||||
return NewAnonymousConversation()
|
||||
}
|
||||
|
||||
if id == -1 {
|
||||
// create new conversation
|
||||
return NewConversation(db, user.GetID(db))
|
||||
}
|
||||
|
||||
// load conversation
|
||||
if instance := LoadConversation(db, user.GetID(db), id); instance != nil {
|
||||
return instance
|
||||
} else {
|
||||
return NewConversation(db, user.GetID(db))
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Conversation) GetModel() string {
|
||||
return c.Model
|
||||
}
|
||||
@ -72,7 +105,7 @@ func (c *Conversation) SetId(id int64) {
|
||||
c.Id = id
|
||||
}
|
||||
|
||||
func (c *Conversation) GetMessage() []types.ChatGPTMessage {
|
||||
func (c *Conversation) GetMessage() []globals.Message {
|
||||
return c.Message
|
||||
}
|
||||
|
||||
@ -80,41 +113,41 @@ func (c *Conversation) GetMessageSize() int {
|
||||
return len(c.Message)
|
||||
}
|
||||
|
||||
func (c *Conversation) GetMessageSegment(length int) []types.ChatGPTMessage {
|
||||
func (c *Conversation) GetMessageSegment(length int) []globals.Message {
|
||||
if length > len(c.Message) {
|
||||
return c.Message
|
||||
}
|
||||
return c.Message[len(c.Message)-length:]
|
||||
}
|
||||
|
||||
func CopyMessage(message []types.ChatGPTMessage) []types.ChatGPTMessage {
|
||||
return utils.UnmarshalJson[[]types.ChatGPTMessage](utils.ToJson(message)) // deep copy
|
||||
func CopyMessage(message []globals.Message) []globals.Message {
|
||||
return utils.UnmarshalJson[[]globals.Message](utils.ToJson(message)) // deep copy
|
||||
}
|
||||
|
||||
func (c *Conversation) GetLastMessage() types.ChatGPTMessage {
|
||||
func (c *Conversation) GetLastMessage() globals.Message {
|
||||
return c.Message[len(c.Message)-1]
|
||||
}
|
||||
|
||||
func (c *Conversation) AddMessage(message types.ChatGPTMessage) {
|
||||
func (c *Conversation) AddMessage(message globals.Message) {
|
||||
c.Message = append(c.Message, message)
|
||||
}
|
||||
|
||||
func (c *Conversation) AddMessageFromUser(message string) {
|
||||
c.AddMessage(types.ChatGPTMessage{
|
||||
c.AddMessage(globals.Message{
|
||||
Role: "user",
|
||||
Content: message,
|
||||
})
|
||||
}
|
||||
|
||||
func (c *Conversation) AddMessageFromAssistant(message string) {
|
||||
c.AddMessage(types.ChatGPTMessage{
|
||||
c.AddMessage(globals.Message{
|
||||
Role: "assistant",
|
||||
Content: message,
|
||||
})
|
||||
}
|
||||
|
||||
func (c *Conversation) AddMessageFromSystem(message string) {
|
||||
c.AddMessage(types.ChatGPTMessage{
|
||||
c.AddMessage(globals.Message{
|
||||
Role: "system",
|
||||
Content: message,
|
||||
})
|
||||
@ -132,7 +165,7 @@ func GetMessage(data []byte) (string, error) {
|
||||
return form.Message, nil
|
||||
}
|
||||
|
||||
func (c *Conversation) AddMessageFromUserForm(data []byte) (string, error) {
|
||||
func (c *Conversation) AddMessageFromByte(data []byte) (string, error) {
|
||||
form, err := utils.Unmarshal[FormMessage](data)
|
||||
if err != nil {
|
||||
return "", err
|
||||
@ -146,9 +179,32 @@ func (c *Conversation) AddMessageFromUserForm(data []byte) (string, error) {
|
||||
return form.Message, nil
|
||||
}
|
||||
|
||||
func (c *Conversation) HandleMessage(db *sql.DB, data []byte) bool {
|
||||
func (c *Conversation) AddMessageFromForm(form *FormMessage) error {
|
||||
if len(form.Message) == 0 {
|
||||
return errors.New("message is empty")
|
||||
}
|
||||
|
||||
c.AddMessageFromUser(form.Message)
|
||||
c.SetModel(form.Model)
|
||||
c.SetEnableWeb(form.Web)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Conversation) HandleMessage(db *sql.DB, form *FormMessage) bool {
|
||||
head := len(c.Message) == 0
|
||||
msg, err := c.AddMessageFromUserForm(data)
|
||||
if err := c.AddMessageFromForm(form); err != nil {
|
||||
return false
|
||||
}
|
||||
if head {
|
||||
c.SetName(db, form.Message)
|
||||
}
|
||||
c.SaveConversation(db)
|
||||
return true
|
||||
}
|
||||
|
||||
func (c *Conversation) HandleMessageFromByte(db *sql.DB, data []byte) bool {
|
||||
head := len(c.Message) == 0
|
||||
msg, err := c.AddMessageFromByte(data)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
12
manager/conversation/router.go
Normal file
12
manager/conversation/router.go
Normal file
@ -0,0 +1,12 @@
|
||||
package conversation
|
||||
|
||||
import "github.com/gin-gonic/gin"
|
||||
|
||||
func Register(app *gin.Engine) {
|
||||
router := app.Group("/conversation")
|
||||
{
|
||||
router.GET("/list", ListAPI)
|
||||
router.GET("/load", LoadAPI)
|
||||
router.GET("/delete", DeleteAPI)
|
||||
}
|
||||
}
|
@ -1,13 +1,18 @@
|
||||
package conversation
|
||||
|
||||
import (
|
||||
"chat/types"
|
||||
"chat/globals"
|
||||
"chat/utils"
|
||||
"database/sql"
|
||||
"log"
|
||||
)
|
||||
|
||||
func (c *Conversation) SaveConversation(db *sql.DB) bool {
|
||||
if c.UserID == -1 {
|
||||
// anonymous request
|
||||
return true
|
||||
}
|
||||
|
||||
data := utils.ToJson(c.GetMessage())
|
||||
query := `INSERT INTO conversation (user_id, conversation_id, conversation_name, data) VALUES (?, ?, ?, ?) ON DUPLICATE KEY UPDATE conversation_name = VALUES(conversation_name), data = VALUES(data)`
|
||||
|
||||
@ -49,7 +54,7 @@ func LoadConversation(db *sql.DB, userId int64, conversationId int64) *Conversat
|
||||
return nil
|
||||
}
|
||||
|
||||
conversation.Message, err = utils.Unmarshal[[]types.ChatGPTMessage]([]byte(data))
|
||||
conversation.Message, err = utils.Unmarshal[[]globals.Message]([]byte(data))
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
49
manager/image.go
Normal file
49
manager/image.go
Normal file
@ -0,0 +1,49 @@
|
||||
package manager
|
||||
|
||||
import (
|
||||
"chat/adapter/chatgpt"
|
||||
"chat/auth"
|
||||
"chat/utils"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
func GetImageLimitFormat(db *sql.DB, user *auth.User) string {
|
||||
return fmt.Sprintf(":imagelimit:%s:%d", time.Now().Format("2006-01-02"), user.GetID(db))
|
||||
}
|
||||
|
||||
func GenerateImage(c *gin.Context, user *auth.User, prompt string) (string, error) {
|
||||
// free plan: 5 images per day
|
||||
// pro plan: 50 images per day
|
||||
|
||||
db := utils.GetDBFromContext(c)
|
||||
cache := utils.GetCacheFromContext(c)
|
||||
|
||||
key := GetImageLimitFormat(db, user)
|
||||
usage := auth.GetDalleUsageLimit(db, user)
|
||||
|
||||
prompt = strings.TrimSpace(prompt)
|
||||
if len(prompt) == 0 {
|
||||
return "", fmt.Errorf("please provide description for the image (e.g. /image an apple)")
|
||||
}
|
||||
|
||||
if utils.IncrWithLimit(cache, key, 1, int64(usage), 60*60*24) || auth.ReduceDalle(db, user) {
|
||||
instance := chatgpt.NewChatInstanceFromModel(&chatgpt.InstanceProps{
|
||||
Model: "dalle",
|
||||
})
|
||||
|
||||
response, err := instance.CreateImage(chatgpt.ImageProps{
|
||||
Prompt: prompt,
|
||||
})
|
||||
if err != nil {
|
||||
return "", err
|
||||
} else {
|
||||
return utils.GetImageMarkdown(response), nil
|
||||
}
|
||||
} else {
|
||||
return "", fmt.Errorf("you have reached your limit of %d free images per day, please buy more quota or wait until tomorrow", usage)
|
||||
}
|
||||
}
|
100
manager/manager.go
Normal file
100
manager/manager.go
Normal file
@ -0,0 +1,100 @@
|
||||
package manager
|
||||
|
||||
import (
|
||||
"chat/auth"
|
||||
"chat/globals"
|
||||
"chat/manager/conversation"
|
||||
"chat/utils"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type WebsocketAuthForm struct {
|
||||
Token string `json:"token" binding:"required"`
|
||||
Id int64 `json:"id" binding:"required"`
|
||||
}
|
||||
|
||||
func EventHandler(conn *utils.WebSocket, instance *conversation.Conversation, user *auth.User) string {
|
||||
if strings.HasPrefix(instance.GetLatestMessage(), "/image") {
|
||||
if user == nil {
|
||||
conn.Send(globals.ChatSegmentResponse{
|
||||
Message: "You need to login to use this feature.",
|
||||
End: true,
|
||||
})
|
||||
return "You need to login to use this feature."
|
||||
}
|
||||
|
||||
prompt := strings.TrimSpace(strings.TrimPrefix(instance.GetLatestMessage(), "/image"))
|
||||
|
||||
if response, err := GenerateImage(conn.GetCtx(), user, prompt); err != nil {
|
||||
conn.Send(globals.ChatSegmentResponse{
|
||||
Message: err.Error(),
|
||||
End: true,
|
||||
})
|
||||
return err.Error()
|
||||
} else {
|
||||
conn.Send(globals.ChatSegmentResponse{
|
||||
Quota: 1.,
|
||||
Message: response,
|
||||
End: true,
|
||||
})
|
||||
return response
|
||||
}
|
||||
} else {
|
||||
return ChatHandler(conn, user, instance)
|
||||
}
|
||||
}
|
||||
|
||||
func ChatAPI(c *gin.Context) {
|
||||
var conn *utils.WebSocket
|
||||
if conn = utils.NewWebsocket(c); conn == nil {
|
||||
return
|
||||
}
|
||||
defer conn.DeferClose()
|
||||
|
||||
db := utils.GetDBFromContext(c)
|
||||
|
||||
var form *WebsocketAuthForm
|
||||
if form = utils.ReadForm[WebsocketAuthForm](conn); form == nil {
|
||||
return
|
||||
}
|
||||
|
||||
user := auth.ParseToken(c, form.Token)
|
||||
authenticated := user != nil
|
||||
|
||||
id := auth.GetId(db, user)
|
||||
|
||||
instance := conversation.ExtractConversation(db, user, id)
|
||||
hash := fmt.Sprintf(":chatthread:%s", utils.Md5Encrypt(utils.Multi(
|
||||
authenticated,
|
||||
strconv.FormatInt(id, 10),
|
||||
c.ClientIP(),
|
||||
)))
|
||||
|
||||
for {
|
||||
var form *conversation.FormMessage
|
||||
if form := utils.ReadForm[conversation.FormMessage](conn); form == nil {
|
||||
return
|
||||
}
|
||||
|
||||
if instance.HandleMessage(db, form) {
|
||||
if !conn.IncrRateWithLimit(
|
||||
hash,
|
||||
utils.Multi[int64](authenticated, globals.ChatMaxThread, globals.AnonymousMaxThread),
|
||||
60,
|
||||
) {
|
||||
conn.Send(globals.ChatSegmentResponse{
|
||||
Message: fmt.Sprintf("You have reached the maximum number of threads (%d) the same time. Please wait for a while.", globals.ChatMaxThread),
|
||||
End: true,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
response := EventHandler(conn, instance, user)
|
||||
conn.DecrRate(hash)
|
||||
instance.SaveResponse(db, response)
|
||||
}
|
||||
}
|
||||
}
|
7
manager/router.go
Normal file
7
manager/router.go
Normal file
@ -0,0 +1,7 @@
|
||||
package manager
|
||||
|
||||
import "github.com/gin-gonic/gin"
|
||||
|
||||
func Register(app *gin.Engine) {
|
||||
app.GET("/chat", ChatAPI)
|
||||
}
|
@ -1,15 +1,16 @@
|
||||
package auth
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"chat/auth"
|
||||
"github.com/gin-gonic/gin"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func Middleware() gin.HandlerFunc {
|
||||
func AuthMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
token := strings.TrimSpace(c.GetHeader("Authorization"))
|
||||
if token != "" {
|
||||
if user := ParseToken(c, token); user != nil {
|
||||
if user := auth.ParseToken(c, token); user != nil {
|
||||
c.Set("token", token)
|
||||
c.Set("auth", true)
|
||||
c.Set("user", user.Username)
|
||||
@ -24,16 +25,3 @@ func Middleware() gin.HandlerFunc {
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func GetToken(c *gin.Context) string {
|
||||
return c.GetString("token")
|
||||
}
|
||||
|
||||
func GetUser(c *gin.Context) *User {
|
||||
if c.GetBool("auth") {
|
||||
return &User{
|
||||
Username: c.GetString("user"),
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
@ -1,24 +1,16 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"chat/globals"
|
||||
"chat/utils"
|
||||
"github.com/gin-gonic/gin"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
var AllowedOrigins = []string{
|
||||
"https://fystart.cn",
|
||||
"https://www.fystart.cn",
|
||||
"https://nio.fystart.cn",
|
||||
"https://chatnio.net",
|
||||
"https://www.chatnio.net",
|
||||
"http://localhost:5173",
|
||||
}
|
||||
|
||||
func CORSMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
origin := c.Request.Header.Get("Origin")
|
||||
if utils.Contains(origin, AllowedOrigins) {
|
||||
if utils.Contains(origin, globals.AllowedOrigins) {
|
||||
c.Writer.Header().Set("Access-Control-Allow-Origin", origin)
|
||||
c.Writer.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
|
||||
c.Writer.Header().Set("Access-Control-Allow-Headers", "Origin, Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization")
|
||||
|
13
middleware/middleware.go
Normal file
13
middleware/middleware.go
Normal file
@ -0,0 +1,13 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"chat/connection"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func RegisterMiddleware(app *gin.Engine) {
|
||||
app.Use(CORSMiddleware())
|
||||
app.Use(BuiltinMiddleWare(connection.ConnectMySQL(), connection.ConnectRedis()))
|
||||
app.Use(ThrottleMiddleware())
|
||||
app.Use(AuthMiddleware())
|
||||
}
|
@ -1,55 +0,0 @@
|
||||
package types
|
||||
|
||||
type ChatGPTMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
type ChatGPTRequest struct {
|
||||
Model string `json:"model"`
|
||||
Messages []ChatGPTMessage `json:"messages"`
|
||||
MaxToken int `json:"max_tokens"`
|
||||
Stream bool `json:"stream"`
|
||||
}
|
||||
|
||||
type ChatGPTRequestWithInfinity struct {
|
||||
Model string `json:"model"`
|
||||
Messages []ChatGPTMessage `json:"messages"`
|
||||
Stream bool `json:"stream"`
|
||||
}
|
||||
|
||||
type ChatGPTImageRequest struct {
|
||||
Prompt string `json:"prompt"`
|
||||
Size string `json:"size"`
|
||||
N int `json:"n"`
|
||||
}
|
||||
|
||||
type ChatGPTStreamResponse struct {
|
||||
ID string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
Created int64 `json:"created"`
|
||||
Model string `json:"model"`
|
||||
Data struct {
|
||||
Choices []struct {
|
||||
Delta struct {
|
||||
Content string `json:"content"`
|
||||
}
|
||||
Index int `json:"index"`
|
||||
} `json:"choices"`
|
||||
} `json:"data"`
|
||||
}
|
||||
|
||||
type ChatSegmentResponse struct {
|
||||
Quota float32 `json:"quota"`
|
||||
Keyword string `json:"keyword"`
|
||||
Message string `json:"message"`
|
||||
End bool `json:"end"`
|
||||
}
|
||||
|
||||
type GenerationSegmentResponse struct {
|
||||
Quota float32 `json:"quota"`
|
||||
Message string `json:"message"`
|
||||
Hash string `json:"hash"`
|
||||
End bool `json:"end"`
|
||||
Error string `json:"error"`
|
||||
}
|
@ -1,8 +1,8 @@
|
||||
package api
|
||||
package utils
|
||||
|
||||
import (
|
||||
"chat/auth"
|
||||
"chat/types"
|
||||
"chat/globals"
|
||||
)
|
||||
|
||||
type Buffer struct {
|
||||
@ -13,7 +13,7 @@ type Buffer struct {
|
||||
Times int `json:"times"`
|
||||
}
|
||||
|
||||
func NewBuffer(model string, history []types.ChatGPTMessage) *Buffer {
|
||||
func NewBuffer(model string, history []globals.Message) *Buffer {
|
||||
return &Buffer{
|
||||
Data: "",
|
||||
Cursor: 0,
|
@ -2,6 +2,7 @@ package utils
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"strconv"
|
||||
"time"
|
||||
@ -37,6 +38,14 @@ func Unmarshal[T interface{}](data []byte) (form T, err error) {
|
||||
return form, err
|
||||
}
|
||||
|
||||
func UnmarshalForm[T interface{}](data string) *T {
|
||||
form, err := Unmarshal[T]([]byte(data))
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
return &form
|
||||
}
|
||||
|
||||
func Marshal[T interface{}](data T) string {
|
||||
res, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
@ -65,3 +74,7 @@ func ToInt(value string) int {
|
||||
func ConvertSqlTime(t time.Time) string {
|
||||
return t.Format("2006-01-02 15:04:05")
|
||||
}
|
||||
|
||||
func GetImageMarkdown(url string) string {
|
||||
return fmt.Sprintf("", url)
|
||||
}
|
||||
|
@ -3,12 +3,13 @@ package utils
|
||||
import (
|
||||
"database/sql"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/go-redis/redis/v8"
|
||||
)
|
||||
|
||||
func GetDBFromContext(c *gin.Context) *sql.DB {
|
||||
return c.MustGet("db").(*sql.DB)
|
||||
}
|
||||
|
||||
func GetCacheFromContext(c *gin.Context) *sql.DB {
|
||||
return c.MustGet("cache").(*sql.DB)
|
||||
func GetCacheFromContext(c *gin.Context) *redis.Client {
|
||||
return c.MustGet("cache").(*redis.Client)
|
||||
}
|
||||
|
12
utils/key.go
Normal file
12
utils/key.go
Normal file
@ -0,0 +1,12 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"math/rand"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func GetRandomKey(apikey string) string {
|
||||
arr := strings.Split(apikey, "|")
|
||||
idx := rand.Intn(len(arr))
|
||||
return arr[idx]
|
||||
}
|
42
utils/net.go
42
utils/net.go
@ -2,10 +2,12 @@ package utils
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/tls"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func Http(uri string, method string, ptr interface{}, headers map[string]string, body io.Reader) (err error) {
|
||||
@ -100,3 +102,43 @@ func PostForm(uri string, body map[string]interface{}) (data map[string]interfac
|
||||
|
||||
return data, nil
|
||||
}
|
||||
|
||||
func EventSource(method string, uri string, headers map[string]string, body interface{}, callback func(string) error) error {
|
||||
http.DefaultTransport.(*http.Transport).TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
|
||||
|
||||
client := &http.Client{}
|
||||
req, err := http.NewRequest(method, uri, ConvertBody(body))
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
for key, value := range headers {
|
||||
req.Header.Set(key, value)
|
||||
}
|
||||
|
||||
res, err := client.Do(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
defer res.Body.Close()
|
||||
for {
|
||||
buf := make([]byte, 20480)
|
||||
n, err := res.Body.Read(buf)
|
||||
|
||||
if err == io.EOF {
|
||||
return nil
|
||||
} else if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
data := string(buf[:n])
|
||||
for _, item := range strings.Split(data, "\n") {
|
||||
segment := strings.TrimSpace(item)
|
||||
if len(segment) > 0 {
|
||||
if err := callback(segment); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1,7 +1,7 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"chat/types"
|
||||
"chat/globals"
|
||||
"fmt"
|
||||
"github.com/pkoukk/tiktoken-go"
|
||||
"strings"
|
||||
@ -13,42 +13,42 @@ import (
|
||||
|
||||
func GetWeightByModel(model string) int {
|
||||
switch model {
|
||||
case types.Claude2,
|
||||
types.Claude2100k:
|
||||
case globals.Claude2,
|
||||
globals.Claude2100k:
|
||||
return 2
|
||||
case types.GPT432k,
|
||||
types.GPT432k0613,
|
||||
types.GPT432k0314:
|
||||
case globals.GPT432k,
|
||||
globals.GPT432k0613,
|
||||
globals.GPT432k0314:
|
||||
return 3 * 10
|
||||
case types.GPT3Turbo,
|
||||
types.GPT3Turbo0613,
|
||||
case globals.GPT3Turbo,
|
||||
globals.GPT3Turbo0613,
|
||||
|
||||
types.GPT3Turbo16k,
|
||||
types.GPT3Turbo16k0613,
|
||||
globals.GPT3Turbo16k,
|
||||
globals.GPT3Turbo16k0613,
|
||||
|
||||
types.GPT4,
|
||||
types.GPT40314,
|
||||
types.GPT40613:
|
||||
globals.GPT4,
|
||||
globals.GPT40314,
|
||||
globals.GPT40613:
|
||||
return 3
|
||||
case types.GPT3Turbo0301, types.GPT3Turbo16k0301:
|
||||
case globals.GPT3Turbo0301, globals.GPT3Turbo16k0301:
|
||||
return 4 // every message follows <|start|>{role/name}\n{content}<|end|>\n
|
||||
default:
|
||||
if strings.Contains(model, types.GPT3Turbo) {
|
||||
if strings.Contains(model, globals.GPT3Turbo) {
|
||||
// warning: gpt-3.5-turbo may update over time. Returning num tokens assuming gpt-3.5-turbo-0613.
|
||||
return GetWeightByModel(types.GPT3Turbo0613)
|
||||
} else if strings.Contains(model, types.GPT4) {
|
||||
return GetWeightByModel(globals.GPT3Turbo0613)
|
||||
} else if strings.Contains(model, globals.GPT4) {
|
||||
// warning: gpt-4 may update over time. Returning num tokens assuming gpt-4-0613.
|
||||
return GetWeightByModel(types.GPT40613)
|
||||
} else if strings.Contains(model, types.Claude2) {
|
||||
return GetWeightByModel(globals.GPT40613)
|
||||
} else if strings.Contains(model, globals.Claude2) {
|
||||
// warning: claude-2 may update over time. Returning num tokens assuming claude-2-100k.
|
||||
return GetWeightByModel(types.Claude2100k)
|
||||
return GetWeightByModel(globals.Claude2100k)
|
||||
} else {
|
||||
// not implemented: See https://github.com/openai/openai-python/blob/main/chatml.md for information on how messages are converted to tokens
|
||||
panic(fmt.Errorf("not implemented for model %s", model))
|
||||
}
|
||||
}
|
||||
}
|
||||
func NumTokensFromMessages(messages []types.ChatGPTMessage, model string) (tokens int) {
|
||||
func NumTokensFromMessages(messages []globals.Message, model string) (tokens int) {
|
||||
weight := GetWeightByModel(model)
|
||||
tkm, err := tiktoken.EncodingForModel(model)
|
||||
if err != nil {
|
||||
@ -68,6 +68,6 @@ func NumTokensFromMessages(messages []types.ChatGPTMessage, model string) (token
|
||||
return tokens
|
||||
}
|
||||
|
||||
func CountTokenPrice(messages []types.ChatGPTMessage, model string) int {
|
||||
func CountTokenPrice(messages []globals.Message, model string) int {
|
||||
return NumTokensFromMessages(messages, model)
|
||||
}
|
||||
|
132
utils/websocket.go
Normal file
132
utils/websocket.go
Normal file
@ -0,0 +1,132 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"chat/globals"
|
||||
"database/sql"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/go-redis/redis/v8"
|
||||
"github.com/gorilla/websocket"
|
||||
"io"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
type WebSocket struct {
|
||||
Ctx *gin.Context
|
||||
Conn *websocket.Conn
|
||||
}
|
||||
|
||||
func CheckUpgrader(c *gin.Context) *websocket.Upgrader {
|
||||
return &websocket.Upgrader{
|
||||
CheckOrigin: func(r *http.Request) bool {
|
||||
origin := c.Request.Header.Get("Origin")
|
||||
if Contains(origin, globals.AllowedOrigins) {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func NewWebsocket(c *gin.Context) *WebSocket {
|
||||
upgrader := CheckUpgrader(c)
|
||||
if conn, err := upgrader.Upgrade(c.Writer, c.Request, nil); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"status": false,
|
||||
"message": "",
|
||||
"reason": err.Error(),
|
||||
})
|
||||
return nil
|
||||
} else {
|
||||
return &WebSocket{
|
||||
Ctx: c,
|
||||
Conn: conn,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (w *WebSocket) Read() (int, []byte, error) {
|
||||
return w.Conn.ReadMessage()
|
||||
}
|
||||
|
||||
func (w *WebSocket) Write(messageType int, data []byte) error {
|
||||
return w.Conn.WriteMessage(messageType, data)
|
||||
}
|
||||
|
||||
func (w *WebSocket) Close() error {
|
||||
return w.Conn.Close()
|
||||
}
|
||||
|
||||
func (w *WebSocket) DeferClose() {
|
||||
if err := w.Close(); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (w *WebSocket) NextWriter(messageType int) (io.WriteCloser, error) {
|
||||
return w.Conn.NextWriter(messageType)
|
||||
}
|
||||
|
||||
func (w *WebSocket) ReadJSON(v interface{}) error {
|
||||
return w.Conn.ReadJSON(v)
|
||||
}
|
||||
|
||||
func (w *WebSocket) SendJSON(v interface{}) error {
|
||||
return w.Conn.WriteJSON(v)
|
||||
}
|
||||
|
||||
func (w *WebSocket) Send(v interface{}) bool {
|
||||
return w.SendJSON(v) == nil
|
||||
}
|
||||
|
||||
func (w *WebSocket) Receive(v interface{}) bool {
|
||||
return w.ReadJSON(v) == nil
|
||||
}
|
||||
|
||||
func (w *WebSocket) SendText(message string) bool {
|
||||
return w.Write(websocket.TextMessage, []byte(message)) == nil
|
||||
}
|
||||
|
||||
func (w *WebSocket) DecrRate(key string) bool {
|
||||
cache := w.GetCache()
|
||||
return DecrInt(cache, key, 1)
|
||||
}
|
||||
|
||||
func (w *WebSocket) IncrRate(key string) bool {
|
||||
cache := w.GetCache()
|
||||
_, err := Incr(cache, key, 1)
|
||||
return err == nil
|
||||
}
|
||||
|
||||
func (w *WebSocket) IncrRateWithLimit(key string, limit int64, expiration int64) bool {
|
||||
cache := w.GetCache()
|
||||
return IncrWithLimit(cache, key, 1, limit, expiration)
|
||||
}
|
||||
|
||||
func (w *WebSocket) GetCtx() *gin.Context {
|
||||
return w.Ctx
|
||||
}
|
||||
|
||||
func (w *WebSocket) GetDB() *sql.DB {
|
||||
return GetDBFromContext(w.Ctx)
|
||||
}
|
||||
|
||||
func (w *WebSocket) GetCache() *redis.Client {
|
||||
return GetCacheFromContext(w.Ctx)
|
||||
}
|
||||
|
||||
func ReadForm[T comparable](w *WebSocket) *T {
|
||||
// golang cannot use generic type in class-like struct
|
||||
// except ping
|
||||
_, message, err := w.Read()
|
||||
if err != nil {
|
||||
return nil
|
||||
} else if string(message) == "{\"type\":\"ping\"}" {
|
||||
return ReadForm[T](w)
|
||||
}
|
||||
|
||||
form, err := Unmarshal[T](message)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
return &form
|
||||
}
|
Loading…
Reference in New Issue
Block a user