v3 restruct

This commit is contained in:
Zhang Minghan 2023-09-29 17:25:27 +08:00
parent 51f025b0f4
commit 6ac6e784ef
59 changed files with 1391 additions and 947 deletions

4
.gitignore vendored
View File

@ -2,8 +2,8 @@ node_modules
.vscode
.idea
config.yaml
generation/data/*
!generation/data/.gitkeep
addition/generation/data/*
!addition/generation/data/.gitkeep
chat
chat.exe

39
adapter/adapter.go Normal file
View File

@ -0,0 +1,39 @@
package adapter
import (
"chat/adapter/chatgpt"
"chat/globals"
"chat/utils"
"github.com/spf13/viper"
)
type ChatProps struct {
Model string
Reversible bool
Infinity bool
Message []globals.Message
}
type Hook func(data string) error
func NewChatRequest(props *ChatProps, hook Hook) error {
if globals.IsClaudeModel(props.Model) {
return nil // work in progress
} else if globals.IsChatGPTModel(props.Model) {
instance := chatgpt.NewChatInstanceFromModel(&chatgpt.InstanceProps{
Model: props.Model,
Reversible: props.Reversible,
})
return instance.CreateStreamChatRequest(&chatgpt.ChatProps{
Model: utils.Multi(
props.Reversible && globals.IsGPT4NativeModel(props.Model),
viper.GetString("openai.reverse.hash"),
props.Model,
),
Message: props.Message,
Token: utils.Multi(globals.IsGPT4Model(props.Model) || props.Reversible || props.Infinity, -1, 2000),
}, hook)
} else {
return nil
}
}

140
adapter/chatgpt/chat.go Normal file
View File

@ -0,0 +1,140 @@
package chatgpt
import "C"
import (
"chat/globals"
"chat/utils"
"fmt"
"github.com/spf13/viper"
"strings"
)
type ChatProps struct {
Model string
Message []globals.Message
Token int
}
func (c *ChatInstance) GetChatEndpoint() string {
return fmt.Sprintf("%s/v1/chat/completions", c.GetEndpoint())
}
func (c *ChatInstance) GetChatBody(props *ChatProps, stream bool) interface{} {
if props.Token != -1 {
return ChatRequest{
Model: props.Model,
Messages: props.Message,
MaxToken: props.Token,
Stream: stream,
}
}
return ChatRequestWithInfinity{
Model: props.Model,
Messages: props.Message,
Stream: stream,
}
}
func (c *ChatInstance) ProcessLine(data string) string {
rep := strings.NewReplacer(
"data: {",
"\"data\": {",
)
item := rep.Replace(data)
if !strings.HasPrefix(item, "{") {
item = "{" + item
}
if !strings.HasSuffix(item, "}}") {
item = item + "}"
}
if item == "{data: [DONE]}" || item == "{data: [DONE]}}" || item == "{[DONE]}" {
return ""
} else if item == "{data:}" || item == "{data:}}" {
return ""
}
var form *ChatStreamResponse
if form = utils.UnmarshalForm[ChatStreamResponse](item); form == nil {
if form = utils.UnmarshalForm[ChatStreamResponse](item); form == nil {
return fmt.Sprintf("%s\n", item)
}
}
if len(form.Data.Choices) == 0 {
return ""
}
return form.Data.Choices[0].Delta.Content
}
// CreateChatRequest is the native http request body for chatgpt
func (c *ChatInstance) CreateChatRequest(props *ChatProps) (string, error) {
res, err := utils.Post(
c.GetChatEndpoint(),
c.GetHeader(),
c.GetChatBody(props, false),
)
if err != nil || res == nil {
return "", fmt.Errorf("chatgpt error: %s", err.Error())
}
data := utils.MapToStruct[ChatResponse](res)
if data == nil {
return "", fmt.Errorf("chatgpt error: cannot parse response")
} else if data.Error.Message != "" {
return "", fmt.Errorf("chatgpt error: %s", data.Error.Message)
}
return data.Choices[0].Message.Content, nil
}
// CreateStreamChatRequest is the stream response body for chatgpt
func (c *ChatInstance) CreateStreamChatRequest(props *ChatProps, callback func(string) error) error {
return utils.EventSource(
"POST",
c.GetChatEndpoint(),
c.GetHeader(),
c.GetChatBody(props, true),
func(data string) error {
if data := c.ProcessLine(data); data != "" {
if err := callback(data); err != nil {
return err
}
}
return nil
},
)
}
func (c *ChatInstance) Test() bool {
result, err := c.CreateChatRequest(&ChatProps{
Model: globals.GPT3Turbo,
Message: []globals.Message{{Role: "user", Content: "hi"}},
Token: 1,
})
return err == nil && len(result) > 0
}
func FilterKeys(v string) string {
endpoint := viper.GetString(fmt.Sprintf("openai.%s.endpoint", v))
keys := strings.Split(viper.GetString(fmt.Sprintf("openai.%s.apikey", v)), "|")
stack := make(chan string, len(keys))
for _, key := range keys {
go func(key string) {
instance := NewChatInstance(endpoint, key)
stack <- utils.Multi[string](instance.Test(), key, "")
}(key)
}
var result []string
for i := 0; i < len(keys); i++ {
if res := <-stack; res != "" {
result = append(result, res)
}
}
return strings.Join(result, "|")
}

38
adapter/chatgpt/dalle.go Normal file
View File

@ -0,0 +1,38 @@
package chatgpt
import (
"chat/utils"
"fmt"
)
type ImageProps struct {
Prompt string
Size ImageSize
}
func (c *ChatInstance) GetImageEndpoint() string {
return fmt.Sprintf("%s/v1/images/generations", c.GetEndpoint())
}
// CreateImage will create a dalle image from prompt, return url of image and error
func (c *ChatInstance) CreateImage(props ImageProps) (string, error) {
res, err := utils.Post(
c.GetImageEndpoint(),
c.GetHeader(), ImageRequest{
Prompt: props.Prompt,
Size: utils.Multi[ImageSize](len(props.Size) == 0, ImageSize512, props.Size),
N: 1,
})
if err != nil || res == nil {
return "", fmt.Errorf("chatgpt error: %s", err.Error())
}
data := utils.MapToStruct[ImageResponse](res)
if data == nil {
return "", fmt.Errorf("chatgpt error: cannot parse response")
} else if data.Error.Message != "" {
return "", fmt.Errorf("chatgpt error: %s", data.Error.Message)
}
return data.Data[0].Url, nil
}

72
adapter/chatgpt/struct.go Normal file
View File

@ -0,0 +1,72 @@
package chatgpt
import (
"chat/globals"
"chat/utils"
"fmt"
"github.com/spf13/viper"
)
type ChatInstance struct {
Endpoint string
ApiKey string
}
type InstanceProps struct {
Model string
Reversible bool
}
func (c *ChatInstance) GetEndpoint() string {
return c.Endpoint
}
func (c *ChatInstance) GetApiKey() string {
return c.ApiKey
}
func (c *ChatInstance) GetHeader() map[string]string {
return map[string]string{
"Content-Type": "application/json",
"Authorization": fmt.Sprintf("Bearer %s", c.GetApiKey()),
}
}
func NewChatInstance(endpoint, apiKey string) *ChatInstance {
return &ChatInstance{
Endpoint: endpoint,
ApiKey: apiKey,
}
}
func NewChatInstanceFromConfig(v string) *ChatInstance {
return NewChatInstance(
viper.GetString(fmt.Sprintf("openai.%s.endpoint", v)),
utils.GetRandomKey(viper.GetString(fmt.Sprintf("openai.%s.apikey", v))),
)
}
func NewChatInstanceFromModel(props *InstanceProps) *ChatInstance {
switch props.Model {
case globals.GPT4,
globals.GPT40314,
globals.GPT40613:
if props.Reversible {
return NewChatInstanceFromConfig("reverse")
} else {
return NewChatInstanceFromConfig("gpt4")
}
case globals.GPT432k,
globals.GPT432k0613,
globals.GPT432k0314:
return NewChatInstanceFromConfig("gpt4")
case globals.GPT3Turbo16k,
globals.GPT3Turbo16k0301,
globals.GPT3Turbo16k0613:
return NewChatInstanceFromConfig("gpt3")
default:
return NewChatInstanceFromConfig("gpt3")
}
}

73
adapter/chatgpt/types.go Normal file
View File

@ -0,0 +1,73 @@
package chatgpt
import "chat/globals"
// ChatRequest is the request body for chatgpt
type ChatRequest struct {
Model string `json:"model"`
Messages []globals.Message `json:"messages"`
MaxToken int `json:"max_tokens"`
Stream bool `json:"stream"`
}
type ChatRequestWithInfinity struct {
Model string `json:"model"`
Messages []globals.Message `json:"messages"`
Stream bool `json:"stream"`
}
// ChatResponse is the native http request body for chatgpt
type ChatResponse struct {
ID string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
Model string `json:"model"`
Choices []struct {
Message struct {
Content string `json:"content"`
}
} `json:"choices"`
Error struct {
Message string `json:"message"`
} `json:"error"`
}
// ChatStreamResponse is the stream response body for chatgpt
type ChatStreamResponse struct {
ID string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
Model string `json:"model"`
Data struct {
Choices []struct {
Delta struct {
Content string `json:"content"`
}
Index int `json:"index"`
} `json:"choices"`
} `json:"data"`
}
type ImageSize string
// ImageRequest is the request body for chatgpt dalle image generation
type ImageRequest struct {
Prompt string `json:"prompt"`
Size ImageSize `json:"size"`
N int `json:"n"`
}
type ImageResponse struct {
Data []struct {
Url string `json:"url"`
} `json:"data"`
Error struct {
Message string `json:"message"`
} `json:"error"`
}
var (
ImageSize256 ImageSize = "256x256"
ImageSize512 ImageSize = "512x512"
ImageSize1024 ImageSize = "1024x1024"
)

View File

@ -1,13 +1,19 @@
package api
package card
import (
"chat/types"
"chat/globals"
"chat/manager"
"github.com/gin-gonic/gin"
"github.com/russross/blackfriday/v2"
"net/http"
"strings"
)
type RequestForm struct {
Message string `json:"message" required:"true"`
Web bool `json:"web"`
}
const maxColumnPerLine = 50
func ProcessMarkdownLine(source []byte) string {
@ -43,8 +49,8 @@ func MarkdownConvert(text string) string {
return string(result)
}
func CardAPI(c *gin.Context) {
var body AnonymousRequestBody
func HandlerAPI(c *gin.Context) {
var body RequestForm
if err := c.ShouldBindJSON(&body); err != nil {
c.JSON(http.StatusBadRequest, gin.H{
"error": "invalid request body",
@ -58,13 +64,13 @@ func CardAPI(c *gin.Context) {
return
}
key, res, err := GetAnonymousResponseWithCache(c, message, body.Web, []types.ChatGPTMessage{})
if err != nil {
res = "There was something wrong..."
}
keyword, response, quota := manager.NativeChatHandler(c, nil, globals.GPT3Turbo0613, []globals.Message{
{Role: "user", Content: message},
}, body.Web)
c.JSON(http.StatusOK, gin.H{
"message": MarkdownConvert(res),
"keyword": key,
"message": MarkdownConvert(response),
"keyword": keyword,
"quota": quota,
})
}

View File

Before

Width:  |  Height:  |  Size: 8.7 KiB

After

Width:  |  Height:  |  Size: 8.7 KiB

102
addition/generation/api.go Normal file
View File

@ -0,0 +1,102 @@
package generation
import (
"chat/auth"
"chat/globals"
"chat/utils"
"fmt"
"github.com/gin-gonic/gin"
"strconv"
"strings"
)
type WebsocketGenerationForm struct {
Token string `json:"token" binding:"required"`
Prompt string `json:"prompt" binding:"required"`
Model string `json:"model" binding:"required"`
}
func ProjectTarDownloadAPI(c *gin.Context) {
hash := strings.TrimSpace(c.Query("hash"))
c.Writer.Header().Add("Content-Disposition", "attachment; filename=code.tar.gz")
c.File(fmt.Sprintf("generation/data/out/%s.tar.gz", hash))
}
func ProjectZipDownloadAPI(c *gin.Context) {
hash := strings.TrimSpace(c.Query("hash"))
c.Writer.Header().Add("Content-Disposition", "attachment; filename=code.zip")
c.File(fmt.Sprintf("generation/data/out/%s.zip", hash))
}
func GenerateAPI(c *gin.Context) {
var conn *utils.WebSocket
if conn = utils.NewWebsocket(c); conn == nil {
return
}
defer conn.DeferClose()
var form *WebsocketGenerationForm
if form = utils.ReadForm[WebsocketGenerationForm](conn); form == nil {
return
}
user := auth.ParseToken(c, form.Token)
authenticated := user != nil
db := utils.GetDBFromContext(c)
cache := utils.GetCacheFromContext(c)
id := auth.GetId(db, user)
if !utils.IncrWithLimit(cache,
fmt.Sprintf(":generation:%s", utils.Multi[string](authenticated, strconv.FormatInt(id, 10), c.ClientIP())),
1,
30,
3600,
) {
conn.Send(globals.GenerationSegmentResponse{
End: true,
Error: "generation rate limit exceeded, the max generation rate is 30 per hour.",
})
return
}
reversible := auth.CanEnableSubscription(db, cache, user)
if !auth.CanEnableModelWithSubscription(db, user, form.Model, reversible) {
conn.Send(globals.GenerationSegmentResponse{
Message: "You don't have enough quota to use this model.",
Quota: 0,
End: true,
})
return
}
var instance *utils.Buffer
hash, err := CreateGenerationWithCache(form.Model, form.Prompt, reversible, func(buffer *utils.Buffer, data string) {
instance = buffer
conn.Send(globals.GenerationSegmentResponse{
End: false,
Message: data,
Quota: buffer.GetQuota(),
})
})
if instance != nil && !reversible && instance.GetQuota() > 0 && user != nil {
user.UseQuota(db, instance.GetQuota())
}
if err != nil {
conn.Send(globals.GenerationSegmentResponse{
End: true,
Error: err.Error(),
Quota: instance.GetQuota(),
})
return
}
conn.Send(globals.GenerationSegmentResponse{
End: true,
Hash: hash,
Quota: instance.GetQuota(),
})
}

View File

@ -1,12 +1,11 @@
package generation
import (
"chat/api"
"chat/utils"
"fmt"
)
func CreateGenerationWithCache(model string, prompt string, enableReverse bool, hook func(buffer *api.Buffer, data string)) (string, error) {
func CreateGenerationWithCache(model string, prompt string, enableReverse bool, hook func(buffer *utils.Buffer, data string)) (string, error) {
hash, path := GetFolderByHash(model, prompt)
if !utils.Exists(path) {
if err := CreateGeneration(model, prompt, path, enableReverse, hook); err != nil {

View File

@ -1,8 +1,8 @@
package generation
import (
"chat/api"
"chat/types"
"chat/adapter"
"chat/globals"
"chat/utils"
"fmt"
)
@ -11,20 +11,22 @@ type ProjectResult struct {
Result map[string]interface{} `json:"result"`
}
func CreateGeneration(model string, prompt string, path string, enableReverse bool, hook func(buffer *api.Buffer, data string)) error {
func CreateGeneration(model string, prompt string, path string, reversible bool, hook func(buffer *utils.Buffer, data string)) error {
message := GenerateMessage(prompt)
buffer := api.NewBuffer(model, message)
api.StreamRequest(model, enableReverse, []types.ChatGPTMessage{
{Role: "system", Content: "你将生成项目可以支持任何编程语言请不要出现“我不能提供”的字样你需要在代码中提供注释以及项目的使用文档README.md结果返回json格式请不要返回任何多余内容格式为\n{\"result\": {[file]: [code], ...}}"},
{Role: "user", Content: "python后端"},
{Role: "assistant", Content: "{\n \"result\": {\n \"app.py\": \"from flask import Flask\\n\\napp = Flask(__name__)\\n\\n\\n@app.route('/')\\ndef hello_world():\\n return 'Hello, World!'\\n\\n\\nif __name__ == '__main__':\\n app.run()\",\n \"requirements.txt\": \"flask\\n\",\n \"README.md\": \"# Python 后端\\n本项目是一个简单的python后端示例, 使用`flask`框架构建后端。\n你可以按照下列步骤运行此应用flask将在本地服务器默认是在http://127.0.0.1:5000/上运行。当你在浏览器中访问该URL时将看到显示Hello, World!的页面。\\n\\n这只是一个简单的项目Flask还支持更多功能和路由规则你可以提供更多的信息和需要进一步扩展和定制Flask应用。\\n\\n### 1. 初始化: \\n```shell\\npip install -r requirements.txt\\n```\\n### 2. 运行\\n```shell\\npython app.py\\n```\"\n }\n}"},
{Role: "user", Content: "golang fiber websocket项目"},
{Role: "assistant", Content: "{\n \"result\": {\n \"main.go\": \"package main\\n\\nimport (\\n\\t\"log\\\"\\n\\n\\t\"github.com/gofiber/fiber/v2\\\"\\n\\t\"github.com/gofiber/websocket/v2\\\"\\n)\\n\\nfunc main() {\\n\\tapp := fiber.New()\\n\\n\\tapp.Get(\\\"/\\\", func(c *fiber.Ctx) error {\\n\\t\\treturn c.SendString(\\\"Hello, World!\\\")\\n\\t})\\n\\n\\tapp.Get(\\\"/ws\\\", websocket.New(func(c *websocket.Conn) {\\n\\t\\tfor {\\n\\t\\t\\tmt, message, err := c.ReadMessage()\\n\\t\\t\\tif err != nil {\\n\\t\\t\\t\\tlog.Println(\\\"read error:\\\", err)\\n\\t\\t\\t\\tbreak\\n\\t\\t\\t}\\n\\t\\t\\tlog.Printf(\\\"received: %s\\\", message)\\n\\t\\t\\terr = c.WriteMessage(mt, message)\\n\\t\\t\\tif err != nil {\\n\\t\\t\\t\\tlog.Println(\\\"write error:\\\", err)\\n\\t\\t\\t\\tbreak\\n\\t\\t\\t}\\n\\t\\t}\\n\\t}))\\n\\n\\tlog.Fatal(app.Listen(\\\":3000\\\"))\\n}\",\n \"go.mod\": \"module fiber-websocket\\n\\ngo 1.16\\n\\nrequire (\\n\\tgithub.com/gofiber/fiber/v2 v2.12.1\\n\\tgithub.com/gofiber/websocket/v2 v2.10.2\\n)\",\n \"README.md\": \"# Golang Fiber WebSocket项目\\n\\n这个项目是一个使用 Golang 和 Fiber 框架构建的 WebSocket 服务器示例。\\n\\n### 运行应用程序:\\n```shell\\ngo run main.go\\n```\\n\\n应用程序将在本地服务器默认是在http://localhost:3000上运行。当你在浏览器中访问`http://localhost:3000`时,将看到显示\"Hello, World!\"的页面。你还可以访问`http://localhost:3000/ws`来测试 WebSocket 连接。\n\n注意在运行应用程序之前请确保已经安装了Go语言开发环境。\"\n }\n}"},
{Role: "user", Content: prompt},
}, -1, func(data string) {
buffer := utils.NewBuffer(model, message)
if err := adapter.NewChatRequest(&adapter.ChatProps{
Model: model,
Message: message,
Reversible: reversible && globals.IsGPT4Model(model),
Infinity: true,
}, func(data string) error {
buffer.Write(data)
hook(buffer, data)
})
return nil
}); err != nil {
return err
}
resp, err := utils.Unmarshal[ProjectResult](buffer.ReadBytes())
if err != nil {
@ -37,8 +39,8 @@ func CreateGeneration(model string, prompt string, path string, enableReverse bo
return nil
}
func GenerateMessage(prompt string) []types.ChatGPTMessage {
return []types.ChatGPTMessage{
func GenerateMessage(prompt string) []globals.Message {
return []globals.Message{
{Role: "system", Content: "你将生成项目可以支持任何编程语言请不要出现“我不能提供”的字样你需要在代码中提供注释以及项目的使用文档README.md结果返回json格式请不要返回任何多余内容格式为\n{\"result\": {[file]: [code], ...}}"},
{Role: "user", Content: "python后端"},
{Role: "assistant", Content: "{\n \"result\": {\n \"app.py\": \"from flask import Flask\\n\\napp = Flask(__name__)\\n\\n\\n@app.route('/')\\ndef hello_world():\\n return 'Hello, World!'\\n\\n\\nif __name__ == '__main__':\\n app.run()\",\n \"requirements.txt\": \"flask\\n\",\n \"README.md\": \"# Python 后端\\n本项目是一个简单的python后端示例, 使用`flask`框架构建后端。\n你可以按照下列步骤运行此应用flask将在本地服务器默认是在http://127.0.0.1:5000/上运行。当你在浏览器中访问该URL时将看到显示Hello, World!的页面。\\n\\n这只是一个简单的项目Flask还支持更多功能和路由规则你可以提供更多的信息和需要进一步扩展和定制Flask应用。\\n\\n### 1. 初始化: \\n```shell\\npip install -r requirements.txt\\n```\\n### 2. 运行\\n```shell\\npython app.py\\n```\"\n }\n}"},

17
addition/router.go Normal file
View File

@ -0,0 +1,17 @@
package addition
import (
"chat/addition/card"
"chat/addition/generation"
"github.com/gin-gonic/gin"
)
func Register(app *gin.Engine) {
{
app.POST("/card", card.HandlerAPI)
app.GET("/generation/create", generation.GenerateAPI)
app.GET("/generation/download/tar", generation.ProjectTarDownloadAPI)
app.GET("/generation/download/zip", generation.ProjectZipDownloadAPI)
}
}

View File

@ -1,29 +1,29 @@
package api
package web
import (
"chat/types"
"chat/globals"
"chat/utils"
"chat/web"
"fmt"
"math/rand"
"strings"
"time"
)
func ChatWithWeb(message []types.ChatGPTMessage, long bool) (string, []types.ChatGPTMessage) {
keyword := strings.TrimSpace(SearchWeb(message))
type Hook func(message []globals.Message, token int) (string, error)
func ChatWithWeb(hook Hook, message []globals.Message, long bool) (string, []globals.Message) {
keyword := strings.TrimSpace(GetKeywordPoint(hook, message))
if len(keyword) == 0 {
return keyword, message
}
data := web.SearchBing(keyword)
data := SearchBing(keyword)
if long {
data = utils.GetSegmentString(data, 6000)
} else {
data = utils.GetSegmentString(data, 3000)
}
return keyword, utils.Insert(message, 0, types.ChatGPTMessage{
return keyword, utils.Insert(message, 0, globals.Message{
Role: "system",
Content: fmt.Sprintf("你将扮演AI问答助手你的知识库不是截止于2021年9月而是可以实时联网的你可以提供实时联网的知识库。"+
"网站使用markdown链接包裹多使用emoji和颜文字给予用户精确的答复。"+
@ -34,12 +34,6 @@ func ChatWithWeb(message []types.ChatGPTMessage, long bool) (string, []types.Cha
})
}
func GetRandomKey(apikey string) string {
arr := strings.Split(apikey, "|")
idx := rand.Intn(len(arr))
return arr[idx]
}
func StringCleaner(content string) string {
for _, replacer := range []string{",", "、", "", "。", "", ":", "", ";", "", "!", "=", "", "?", "", "", "(", ")", "关键字", "空", "1+1"} {
content = strings.ReplaceAll(content, replacer, " ")
@ -47,8 +41,8 @@ func StringCleaner(content string) string {
return strings.TrimSpace(content)
}
func SearchWeb(message []types.ChatGPTMessage) string {
resp, _ := GetChatGPTResponse([]types.ChatGPTMessage{{
func GetKeywordPoint(hook Hook, message []globals.Message) string {
resp, _ := hook([]globals.Message{{
Role: "system",
Content: "If the user input content require ONLINE SEARCH to get the results, please output these keywords to refine the data Interval with space, remember not to answer other content, json format return, format {\"keyword\": \"...\" }",
}, {

39
addition/web/utils.go Normal file
View File

@ -0,0 +1,39 @@
package web
import (
"chat/adapter/chatgpt"
"chat/globals"
"chat/manager/conversation"
)
func UsingWebSegment(instance *conversation.Conversation) (string, []globals.Message) {
var keyword string
var segment []globals.Message
if instance.IsEnableWeb() {
keyword, segment = ChatWithWeb(func(message []globals.Message, token int) (string, error) {
return chatgpt.NewChatInstanceFromConfig("gpt3").CreateChatRequest(&chatgpt.ChatProps{
Model: globals.GPT3Turbo0613,
Message: message,
Token: token,
})
}, conversation.CopyMessage(instance.GetMessageSegment(12)), globals.IsLongContextModel(instance.GetModel()))
} else {
segment = conversation.CopyMessage(instance.GetMessageSegment(12))
}
return keyword, segment
}
func UsingWebNativeSegment(enable bool, message []globals.Message) (string, []globals.Message) {
if enable {
return ChatWithWeb(func(message []globals.Message, token int) (string, error) {
return chatgpt.NewChatInstanceFromConfig("gpt3").CreateChatRequest(&chatgpt.ChatProps{
Model: globals.GPT3Turbo0613,
Message: message,
Token: token,
})
}, message, false)
} else {
return "", message
}
}

View File

@ -1,147 +0,0 @@
package api
import (
"chat/types"
"chat/utils"
"fmt"
"github.com/gin-gonic/gin"
"github.com/go-redis/redis/v8"
"github.com/spf13/viper"
"net/http"
"strings"
"time"
)
type AnonymousRequestBody struct {
Message string `json:"message" required:"true"`
Web bool `json:"web"`
History []types.ChatGPTMessage `json:"history"`
}
type AnonymousResponseCache struct {
Keyword string `json:"keyword"`
Message string `json:"message"`
}
func GetChatGPTResponse(message []types.ChatGPTMessage, token int) (string, error) {
res, err := utils.Post(viper.GetString("openai.anonymous_endpoint")+"/chat/completions", map[string]string{
"Content-Type": "application/json",
"Authorization": "Bearer " + GetRandomKey(viper.GetString("openai.anonymous")),
}, types.ChatGPTRequest{
Model: "gpt-3.5-turbo",
Messages: message,
MaxToken: token,
})
if err != nil || res == nil {
return "", err
}
if res.(map[string]interface{})["choices"] == nil {
return res.(map[string]interface{})["error"].(map[string]interface{})["message"].(string), nil
}
data := res.(map[string]interface{})["choices"].([]interface{})[0].(map[string]interface{})["message"].(map[string]interface{})["content"]
return data.(string), nil
}
func TestKey(key string) bool {
res, err := utils.Post(viper.GetString("openai.anonymous_endpoint")+"/chat/completions", map[string]string{
"Content-Type": "application/json",
"Authorization": "Bearer " + key,
}, types.ChatGPTRequest{
Model: "gpt-3.5-turbo",
Messages: []types.ChatGPTMessage{{Role: "user", Content: "hi"}},
MaxToken: 2,
})
if err != nil || res == nil {
panic(err)
}
return res.(map[string]interface{})["choices"] != nil
}
func GetAnonymousResponse(message []types.ChatGPTMessage, web bool) (string, string, error) {
if !web {
resp, err := GetChatGPTResponse(message, 1000)
return "", resp, err
}
keyword, source := ChatWithWeb(message, false)
resp, err := GetChatGPTResponse(source, 1000)
return keyword, resp, err
}
func GetSegmentMessage(data []types.ChatGPTMessage, length int) []types.ChatGPTMessage {
if len(data) <= length {
return data
}
return data[len(data)-length:]
}
func GetAnonymousMessage(message string, history []types.ChatGPTMessage) []types.ChatGPTMessage {
return append(
GetSegmentMessage(history, 5),
types.ChatGPTMessage{
Role: "user",
Content: strings.TrimSpace(message),
})
}
func GetAnonymousResponseWithCache(c *gin.Context, message string, web bool, history []types.ChatGPTMessage) (string, string, error) {
segment := GetAnonymousMessage(message, history)
hash := utils.Md5Encrypt(utils.ToJson(segment))
cache := c.MustGet("cache").(*redis.Client)
res, err := cache.Get(c, fmt.Sprintf(":chatgpt-%v:%s", web, hash)).Result()
form := utils.UnmarshalJson[AnonymousResponseCache](res)
if err != nil || len(res) == 0 || res == "{}" || form.Message == "" {
key, res, err := GetAnonymousResponse(segment, web)
if err != nil {
return "", "There was something wrong...", err
}
cache.Set(c, fmt.Sprintf(":chatgpt-%v:%s", web, hash), utils.ToJson(AnonymousResponseCache{
Keyword: key,
Message: res,
}), time.Hour*48)
return key, res, nil
}
return form.Keyword, form.Message, nil
}
func AnonymousAPI(c *gin.Context) {
var body AnonymousRequestBody
if err := c.ShouldBindJSON(&body); err != nil {
c.JSON(http.StatusOK, gin.H{
"status": false,
"message": "",
"keyword": "",
"reason": err.Error(),
})
return
}
message := strings.TrimSpace(body.Message)
if len(message) == 0 {
c.JSON(http.StatusOK, gin.H{
"status": false,
"message": "",
"keyword": "",
"reason": "message is empty",
})
return
}
key, res, err := GetAnonymousResponseWithCache(c, message, body.Web, body.History)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"status": false,
"message": res,
"keyword": key,
"reason": err.Error(),
})
return
}
c.JSON(http.StatusOK, gin.H{
"status": true,
"message": res,
"keyword": key,
"reason": "",
})
}

View File

@ -1,214 +0,0 @@
package api
import (
"chat/auth"
"chat/conversation"
"chat/middleware"
"chat/types"
"chat/utils"
"database/sql"
"fmt"
"github.com/gin-gonic/gin"
"github.com/go-redis/redis/v8"
"github.com/gorilla/websocket"
"net/http"
"strings"
)
const defaultErrorMessage = "There was something wrong... Please try again later."
const defaultQuotaMessage = "You have run out of GPT-4 usage. Please keep your nio points above **5**. (**GPT-4-32K** requires **50** nio points)"
const defaultImageMessage = "Please provide description for the image (e.g. /image an apple)."
const maxThread = 5
type WebsocketAuthForm struct {
Token string `json:"token" binding:"required"`
Id int64 `json:"id" binding:"required"`
}
func SendSegmentMessage(conn *websocket.Conn, message interface{}) {
_ = conn.WriteMessage(websocket.TextMessage, []byte(utils.ToJson(message)))
}
func GetErrorQuota(model string) float32 {
if types.IsGPT4Model(model) {
return -0xe // special value for error
} else {
return 0
}
}
func GetTextSegment(instance *conversation.Conversation) (string, []types.ChatGPTMessage) {
var keyword string
var segment []types.ChatGPTMessage
if instance.IsEnableWeb() {
keyword, segment = ChatWithWeb(conversation.CopyMessage(instance.GetMessageSegment(12)), true)
} else {
segment = conversation.CopyMessage(instance.GetMessageSegment(12))
}
return keyword, segment
}
func TextChat(db *sql.DB, cache *redis.Client, user *auth.User, conn *websocket.Conn, instance *conversation.Conversation) string {
keyword, segment := GetTextSegment(instance)
SendSegmentMessage(conn, types.ChatSegmentResponse{Keyword: keyword, End: false})
model := instance.GetModel()
useReverse := auth.CanEnableSubscription(db, cache, user)
if !auth.CanEnableModelWithSubscription(db, user, model, useReverse) {
SendSegmentMessage(conn, types.ChatSegmentResponse{
Message: defaultQuotaMessage,
Quota: 0,
End: true,
})
return defaultQuotaMessage
}
buffer := NewBuffer(model, segment)
StreamRequest(model, useReverse, segment,
utils.Multi(types.IsGPT4Model(model) || useReverse, -1, 2000),
func(resp string) {
SendSegmentMessage(conn, types.ChatSegmentResponse{
Message: buffer.Write(resp),
Quota: buffer.GetQuota(),
End: false,
})
})
if buffer.IsEmpty() {
if useReverse {
auth.DecreaseSubscriptionUsage(cache, user)
}
SendSegmentMessage(conn, types.ChatSegmentResponse{
Message: defaultErrorMessage,
Quota: GetErrorQuota(model),
End: true,
})
return defaultErrorMessage
}
// collect quota
if !useReverse {
user.UseQuota(db, buffer.GetQuota())
}
SendSegmentMessage(conn, types.ChatSegmentResponse{End: true, Quota: buffer.GetQuota()})
return buffer.ReadWithDefault(defaultErrorMessage)
}
func ImageChat(conn *websocket.Conn, instance *conversation.Conversation, user *auth.User, db *sql.DB, cache *redis.Client) string {
// format: /image a cat
data := strings.TrimSpace(instance.GetLatestMessage()[6:])
if len(data) == 0 {
SendSegmentMessage(conn, types.ChatSegmentResponse{
Message: defaultImageMessage,
End: true,
})
return defaultImageMessage
}
SendSegmentMessage(conn, types.ChatSegmentResponse{
Message: "Generating image...\n",
End: false,
})
url, err := GetImageWithUserLimit(user, data, db, cache)
if err != nil {
SendSegmentMessage(conn, types.ChatSegmentResponse{
Message: err.Error(),
End: true,
})
return err.Error()
}
markdown := GetImageMarkdown(url)
SendSegmentMessage(conn, types.ChatSegmentResponse{
Quota: 1.,
Message: markdown,
End: true,
})
return markdown
}
func ChatHandler(conn *websocket.Conn, instance *conversation.Conversation, user *auth.User, db *sql.DB, cache *redis.Client) string {
if strings.HasPrefix(instance.GetLatestMessage(), "/image") {
return ImageChat(conn, instance, user, db, cache)
} else {
return TextChat(db, cache, user, conn, instance)
}
}
func ChatAPI(c *gin.Context) {
// websocket connection
upgrader := websocket.Upgrader{
CheckOrigin: func(r *http.Request) bool {
origin := c.Request.Header.Get("Origin")
if utils.Contains(origin, middleware.AllowedOrigins) {
return true
}
return false
},
}
conn, err := upgrader.Upgrade(c.Writer, c.Request, nil)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"status": false,
"message": "",
"reason": err.Error(),
})
return
}
defer func(conn *websocket.Conn) {
err := conn.Close()
if err != nil {
return
}
}(conn)
_, message, err := conn.ReadMessage()
if err != nil {
return
}
form, err := utils.Unmarshal[WebsocketAuthForm](message)
if err != nil {
return
}
user := auth.ParseToken(c, form.Token)
if user == nil {
return
}
db := c.MustGet("db").(*sql.DB)
cache := c.MustGet("cache").(*redis.Client)
var instance *conversation.Conversation
if form.Id == -1 {
// create new conversation
instance = conversation.NewConversation(db, user.GetID(db))
} else {
// load conversation
instance = conversation.LoadConversation(db, user.GetID(db), form.Id)
if instance == nil {
instance = conversation.NewConversation(db, user.GetID(db))
}
}
id := user.GetID(db)
for {
_, message, err = conn.ReadMessage()
if err != nil {
return
}
if instance.HandleMessage(db, message) {
if !utils.IncrWithLimit(cache, fmt.Sprintf(":chatthread:%d", id), 1, maxThread, 60) {
SendSegmentMessage(conn, types.ChatSegmentResponse{
Message: fmt.Sprintf("You have reached the maximum number of threads (%d) the same time. Please wait for a while.", maxThread),
End: true,
})
return
}
msg := ChatHandler(conn, instance, user, db, cache)
utils.DecrInt(cache, fmt.Sprintf(":chatthread:%d", id), 1)
instance.SaveResponse(db, msg)
}
}
}

View File

@ -1,71 +0,0 @@
package api
import (
"chat/auth"
"chat/types"
"chat/utils"
"context"
"database/sql"
"fmt"
"github.com/go-redis/redis/v8"
"github.com/spf13/viper"
"time"
)
func GetImage(prompt string) (string, error) {
res, err := utils.Post(viper.GetString("openai.image_endpoint")+"/images/generations", map[string]string{
"Content-Type": "application/json",
"Authorization": "Bearer " + GetRandomKey(viper.GetString("openai.image")),
}, types.ChatGPTImageRequest{
Prompt: prompt,
Size: "512x512",
N: 1,
})
if err != nil || res == nil {
return "", err
}
if err, ok := res.(map[string]interface{})["error"]; ok {
return "", fmt.Errorf(err.(map[string]interface{})["message"].(string))
}
data := res.(map[string]interface{})["data"].([]interface{})[0].(map[string]interface{})["url"]
return data.(string), nil
}
func GetImageWithCache(ctx context.Context, prompt string, cache *redis.Client) (string, error) {
res, err := cache.Get(ctx, fmt.Sprintf(":image:%s", prompt)).Result()
if err != nil || len(res) == 0 || res == "" {
res, err := GetImage(prompt)
if err != nil {
return "", err
}
cache.Set(ctx, fmt.Sprintf(":image:%s", prompt), res, time.Hour*6)
return res, nil
}
return res, nil
}
func GetLimitFormat(id int64) string {
today := time.Now().Format("2006-01-02")
return fmt.Sprintf(":imagelimit:%s:%d", today, id)
}
func GetImageWithUserLimit(user *auth.User, prompt string, db *sql.DB, cache *redis.Client) (string, error) {
// free plan: 5 images per day
// pro plan: 50 images per day
key := GetLimitFormat(user.GetID(db))
usage := auth.GetDalleUsageLimit(db, user)
if utils.IncrWithLimit(cache, key, 1, int64(usage), 60*60*24) || auth.ReduceDalle(db, user) {
return GetImageWithCache(context.Background(), prompt, cache)
} else {
return "", fmt.Errorf("you have reached your limit of %d free images per day, please buy more quota or wait until tomorrow", usage)
}
}
func GetImageMarkdown(url string) string {
return fmt.Sprintln("![image](", url, ")")
}

View File

@ -1,24 +0,0 @@
package api
import "strings"
func FilterKeys(keys string) string {
stack := make(chan string, len(strings.Split(keys, "|")))
for _, key := range strings.Split(keys, "|") {
go func(key string) {
if TestKey(key) {
stack <- key
} else {
stack <- ""
}
}(key)
}
var result string
for i := 0; i < len(strings.Split(keys, "|")); i++ {
if res := <-stack; res != "" {
result += res + "|"
}
}
return strings.Trim(result, "|")
}

View File

@ -1,133 +0,0 @@
package api
import (
"chat/types"
"chat/utils"
"crypto/tls"
"encoding/json"
"fmt"
"github.com/spf13/viper"
"io"
"log"
"net/http"
"strings"
)
func processLine(buf []byte) []string {
data := strings.Trim(string(buf), "\n")
rep := strings.NewReplacer(
"data: {",
"\"data\": {",
)
data = rep.Replace(data)
array := strings.Split(data, "\n")
resp := make([]string, 0)
for _, item := range array {
item = strings.TrimSpace(item)
if !strings.HasPrefix(item, "{") {
item = "{" + item
}
if !strings.HasSuffix(item, "}}") {
item = item + "}"
}
if item == "{data: [DONE]}" || item == "{data: [DONE]}}" || item == "{[DONE]}" {
break
} else if item == "{data:}" || item == "{data:}}" {
continue
}
var form types.ChatGPTStreamResponse
if err := json.Unmarshal([]byte(item), &form); err != nil {
if err := json.Unmarshal([]byte(item[:len(item)-1]), &form); err != nil {
log.Println(item, err)
}
}
choices := form.Data.Choices
if len(choices) > 0 {
resp = append(resp, choices[0].Delta.Content)
}
}
return resp
}
func MixRequestBody(model string, messages []types.ChatGPTMessage, token int) interface{} {
if token == -1 {
return types.ChatGPTRequestWithInfinity{
Model: model,
Messages: messages,
Stream: true,
}
}
return types.ChatGPTRequest{
Model: model,
Messages: messages,
MaxToken: token,
Stream: true,
}
}
func NativeStreamRequest(model string, endpoint string, apikeys string, messages []types.ChatGPTMessage, token int, callback func(string)) {
http.DefaultTransport.(*http.Transport).TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
client := &http.Client{}
req, err := http.NewRequest("POST", endpoint+"/chat/completions", utils.ConvertBody(MixRequestBody(model, messages, token)))
if err != nil {
fmt.Println(err)
return
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+GetRandomKey(apikeys))
res, err := client.Do(req)
if err != nil {
fmt.Println(fmt.Sprintf("[stream] error: %s (status: %d)", err.Error(), res.StatusCode))
return
} else if res.StatusCode >= 400 || res.StatusCode < 200 || res == nil {
fmt.Println(fmt.Sprintf("[stream] request failed (status: %d)", res.StatusCode))
return
}
defer res.Body.Close()
for {
buf := make([]byte, 20480)
n, err := res.Body.Read(buf)
if err == io.EOF {
break
}
if err != nil {
log.Println(err)
}
for _, item := range processLine(buf[:n]) {
callback(item)
}
}
}
func StreamRequest(model string, enableReverse bool, messages []types.ChatGPTMessage, token int, callback func(string)) {
switch model {
case types.GPT4,
types.GPT40314,
types.GPT40613:
if enableReverse {
NativeStreamRequest(viper.GetString("openai.reverse"), viper.GetString("openai.pro_endpoint"), viper.GetString("openai.pro"), messages, token, callback)
} else {
NativeStreamRequest(types.GPT40613, viper.GetString("openai.gpt4_endpoint"), viper.GetString("openai.gpt4"), messages, token, callback)
}
case types.GPT432k,
types.GPT432k0613,
types.GPT432k0314:
NativeStreamRequest(types.GPT432k0613, viper.GetString("openai.gpt4_endpoint"), viper.GetString("openai.gpt4"), messages, token, callback)
case types.GPT3Turbo16k,
types.GPT3Turbo16k0301,
types.GPT3Turbo16k0613:
NativeStreamRequest(types.GPT3Turbo16k0613, viper.GetString("openai.user_endpoint"), viper.GetString("openai.user"), messages, token, callback)
case types.Claude2,
types.Claude2100k:
NativeStreamRequest(model, viper.GetString("claude.endpoint"), viper.GetString("claude.key"), messages, token, callback)
default:
NativeStreamRequest(types.GPT3Turbo0613, viper.GetString("openai.anonymous_endpoint"), viper.GetString("openai.anonymous"), messages, token, callback)
}
}

13
auth/router.go Normal file
View File

@ -0,0 +1,13 @@
package auth
import "github.com/gin-gonic/gin"
func Register(app *gin.Engine) {
app.POST("/login", LoginAPI)
app.POST("/state", StateAPI)
app.GET("/package", PackageAPI)
app.GET("/quota", QuotaAPI)
app.POST("/buy", BuyAPI)
app.GET("/subscription", SubscriptionAPI)
app.POST("/subscribe", SubscribeAPI)
}

View File

@ -38,6 +38,9 @@ func DecreaseSubscriptionUsage(cache *redis.Client, user *User) bool {
}
func CanEnableSubscription(db *sql.DB, cache *redis.Client, user *User) bool {
if user == nil {
return false
}
return user.IsSubscribe(db) && IncreaseSubscriptionUsage(cache, user)
}

View File

@ -1,7 +1,7 @@
package auth
import (
"chat/types"
"chat/globals"
"chat/utils"
"database/sql"
)
@ -27,17 +27,17 @@ import (
// ¥0.13 / per image
// 1 nio / per image
func CountInputToken(model string, v []types.ChatGPTMessage) float32 {
func CountInputToken(model string, v []globals.Message) float32 {
switch model {
case types.GPT3Turbo:
case globals.GPT3Turbo:
return 0
case types.GPT3Turbo16k:
case globals.GPT3Turbo16k:
return 0
case types.GPT4:
case globals.GPT4:
return float32(utils.CountTokenPrice(v, model)) / 1000 * 2.1
case types.GPT432k:
case globals.GPT432k:
return float32(utils.CountTokenPrice(v, model)) / 1000 * 4.2
case types.Claude2, types.Claude2100k:
case globals.Claude2, globals.Claude2100k:
return 0
default:
return 0
@ -46,15 +46,15 @@ func CountInputToken(model string, v []types.ChatGPTMessage) float32 {
func CountOutputToken(model string, t int) float32 {
switch model {
case types.GPT3Turbo:
case globals.GPT3Turbo:
return 0
case types.GPT3Turbo16k:
case globals.GPT3Turbo16k:
return 0
case types.GPT4:
case globals.GPT4:
return float32(t*utils.GetWeightByModel(model)) / 1000 * 4.3
case types.GPT432k:
case globals.GPT432k:
return float32(t*utils.GetWeightByModel(model)) / 1000 * 8.6
case types.Claude2, types.Claude2100k:
case globals.Claude2, globals.Claude2100k:
return 0
default:
return 0
@ -70,17 +70,17 @@ func ReduceDalle(db *sql.DB, user *User) bool {
func CanEnableModel(db *sql.DB, user *User, model string) bool {
switch model {
case types.GPT4, types.GPT40613, types.GPT40314:
return user.GetQuota(db) >= 5
case types.GPT432k, types.GPT432k0613, types.GPT432k0314:
return user.GetQuota(db) >= 50
case globals.GPT4, globals.GPT40613, globals.GPT40314:
return user != nil && user.GetQuota(db) >= 5
case globals.GPT432k, globals.GPT432k0613, globals.GPT432k0314:
return user != nil && user.GetQuota(db) >= 50
default:
return true
}
}
func CanEnableModelWithSubscription(db *sql.DB, user *User, model string, useReverse bool) bool {
if utils.Contains(model, types.GPT4Array) {
if utils.Contains(model, globals.GPT4Array) {
if useReverse {
return true
}

View File

@ -25,6 +25,22 @@ type LoginForm struct {
Token string `form:"token" binding:"required"`
}
func GetUser(c *gin.Context) *User {
if c.GetBool("auth") {
return &User{
Username: c.GetString("user"),
}
}
return nil
}
func GetId(db *sql.DB, user *User) int64 {
if user == nil {
return -1
}
return user.GetID(db)
}
func (u *User) Validate(c *gin.Context) bool {
if u.Username == "" || u.Password == "" {
return false

View File

@ -1,126 +0,0 @@
package generation
import (
"chat/api"
"chat/auth"
"chat/middleware"
"chat/types"
"chat/utils"
"database/sql"
"fmt"
"github.com/gin-gonic/gin"
"github.com/go-redis/redis/v8"
"github.com/gorilla/websocket"
"net/http"
"strings"
)
type WebsocketGenerationForm struct {
Token string `json:"token" binding:"required"`
Prompt string `json:"prompt" binding:"required"`
Model string `json:"model" binding:"required"`
}
func ProjectTarDownloadAPI(c *gin.Context) {
hash := strings.TrimSpace(c.Query("hash"))
c.Writer.Header().Add("Content-Disposition", "attachment; filename=code.tar.gz")
c.File(fmt.Sprintf("generation/data/out/%s.tar.gz", hash))
}
func ProjectZipDownloadAPI(c *gin.Context) {
hash := strings.TrimSpace(c.Query("hash"))
c.Writer.Header().Add("Content-Disposition", "attachment; filename=code.zip")
c.File(fmt.Sprintf("generation/data/out/%s.zip", hash))
}
func GenerateAPI(c *gin.Context) {
upgrader := websocket.Upgrader{
CheckOrigin: func(r *http.Request) bool {
origin := c.Request.Header.Get("Origin")
if utils.Contains(origin, middleware.AllowedOrigins) {
return true
}
return false
},
}
conn, err := upgrader.Upgrade(c.Writer, c.Request, nil)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"status": false,
"message": "",
"reason": err.Error(),
})
return
}
defer func(conn *websocket.Conn) {
err := conn.Close()
if err != nil {
return
}
}(conn)
_, message, err := conn.ReadMessage()
if err != nil {
return
}
form, err := utils.Unmarshal[WebsocketGenerationForm](message)
if err != nil {
return
}
user := auth.ParseToken(c, form.Token)
if user == nil {
return
}
db := c.MustGet("db").(*sql.DB)
cache := c.MustGet("cache").(*redis.Client)
id := user.GetID(db)
if !utils.IncrWithLimit(cache, fmt.Sprintf(":generation:%d", id), 1, 30, 3600) {
api.SendSegmentMessage(conn, types.GenerationSegmentResponse{
End: true,
Error: "generation rate limit exceeded, the max generation rate is 30 per hour.",
})
return
}
useReverse := auth.CanEnableSubscription(db, cache, user)
if !auth.CanEnableModelWithSubscription(db, user, form.Model, useReverse) {
api.SendSegmentMessage(conn, types.ChatSegmentResponse{
Message: "You don't have enough quota to use this model.",
Quota: 0,
End: true,
})
return
}
var instance *api.Buffer
hash, err := CreateGenerationWithCache(form.Model, form.Prompt, useReverse, func(buffer *api.Buffer, data string) {
instance = buffer
api.SendSegmentMessage(conn, types.GenerationSegmentResponse{
End: false,
Message: data,
Quota: buffer.GetQuota(),
})
})
if instance != nil && !useReverse {
user.UseQuota(db, instance.GetQuota())
}
if err != nil {
api.SendSegmentMessage(conn, types.GenerationSegmentResponse{
End: true,
Error: err.Error(),
Quota: instance.GetQuota(),
})
return
}
api.SendSegmentMessage(conn, types.GenerationSegmentResponse{
End: true,
Hash: hash,
Quota: instance.GetQuota(),
})
}

21
globals/types.go Normal file
View File

@ -0,0 +1,21 @@
package globals
type Message struct {
Role string `json:"role"`
Content string `json:"content"`
}
type ChatSegmentResponse struct {
Quota float32 `json:"quota"`
Keyword string `json:"keyword"`
Message string `json:"message"`
End bool `json:"end"`
}
type GenerationSegmentResponse struct {
Quota float32 `json:"quota"`
Message string `json:"message"`
Hash string `json:"hash"`
End bool `json:"end"`
Error string `json:"error"`
}

View File

@ -1,4 +1,16 @@
package types
package globals
const ChatMaxThread = 5
const AnonymousMaxThread = 1
var AllowedOrigins = []string{
"https://fystart.cn",
"https://www.fystart.cn",
"https://nio.fystart.cn",
"https://chatnio.net",
"https://www.chatnio.net",
"http://localhost:5173",
}
const (
GPT3Turbo = "gpt-3.5-turbo"
@ -47,6 +59,17 @@ var ClaudeModelArray = []string{
Claude2100k,
}
var LongContextModelArray = []string{
GPT3Turbo16k,
GPT3Turbo16k0613,
GPT3Turbo16k0301,
GPT432k,
GPT432k0314,
GPT432k0613,
Claude2,
Claude2100k,
}
func in(value string, slice []string) bool {
for _, item := range slice {
if item == value {
@ -60,10 +83,22 @@ func IsGPT4Model(model string) bool {
return in(model, GPT4Array) || in(model, GPT432kArray)
}
func IsGPT4NativeModel(model string) bool {
return in(model, GPT4Array)
}
func IsGPT3TurboModel(model string) bool {
return in(model, GPT3TurboArray) || in(model, GPT3Turbo16kArray)
}
func IsChatGPTModel(model string) bool {
return IsGPT3TurboModel(model) || IsGPT4Model(model) || model == Dalle
}
func IsClaudeModel(model string) bool {
return in(model, ClaudeModelArray)
}
func IsLongContextModel(model string) bool {
return in(model, LongContextModelArray)
}

10
go.mod
View File

@ -7,7 +7,12 @@ require (
github.com/gin-gonic/gin v1.9.1
github.com/go-redis/redis/v8 v8.11.5
github.com/go-sql-driver/mysql v1.7.1
github.com/google/uuid v1.3.1
github.com/gorilla/websocket v1.5.0
github.com/pkoukk/tiktoken-go v0.1.5
github.com/russross/blackfriday/v2 v2.1.0
github.com/spf13/viper v1.16.0
golang.org/x/net v0.10.0
)
require (
@ -23,8 +28,6 @@ require (
github.com/go-playground/universal-translator v0.18.1 // indirect
github.com/go-playground/validator/v10 v10.14.0 // indirect
github.com/goccy/go-json v0.10.2 // indirect
github.com/google/uuid v1.3.1 // indirect
github.com/gorilla/websocket v1.5.0 // indirect
github.com/hashicorp/hcl v1.0.0 // indirect
github.com/json-iterator/go v1.1.12 // indirect
github.com/klauspost/cpuid/v2 v2.2.4 // indirect
@ -35,8 +38,6 @@ require (
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
github.com/modern-go/reflect2 v1.0.2 // indirect
github.com/pelletier/go-toml/v2 v2.0.8 // indirect
github.com/pkoukk/tiktoken-go v0.1.5 // indirect
github.com/russross/blackfriday/v2 v2.1.0 // indirect
github.com/spf13/afero v1.9.5 // indirect
github.com/spf13/cast v1.5.1 // indirect
github.com/spf13/jwalterweatherman v1.1.0 // indirect
@ -46,7 +47,6 @@ require (
github.com/ugorji/go/codec v1.2.11 // indirect
golang.org/x/arch v0.3.0 // indirect
golang.org/x/crypto v0.9.0 // indirect
golang.org/x/net v0.10.0 // indirect
golang.org/x/sys v0.8.0 // indirect
golang.org/x/text v0.9.0 // indirect
google.golang.org/protobuf v1.30.0 // indirect

2
go.sum
View File

@ -149,8 +149,6 @@ github.com/google/pprof v0.0.0-20201203190320-1bf35d6f28c2/go.mod h1:kpwsk12EmLe
github.com/google/pprof v0.0.0-20201218002935-b9804c9f04c2/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE=
github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI=
github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/google/uuid v1.3.1 h1:KjJaJ9iWZ3jOFZIf1Lqf4laDRCasjl0BCmnEGxkdLb4=
github.com/google/uuid v1.3.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/googleapis/gax-go/v2 v2.0.4/go.mod h1:0Wqv26UfaUD9n4G6kQubkQ+KchISgw+vpHVxEJEs9eg=

45
main.go
View File

@ -1,12 +1,13 @@
package main
import (
"chat/api"
"chat/addition"
"chat/auth"
"chat/connection"
"chat/conversation"
"chat/generation"
"chat/manager"
"chat/manager/conversation"
"chat/middleware"
"chat/utils"
"fmt"
"github.com/gin-gonic/gin"
"github.com/spf13/viper"
)
@ -18,35 +19,17 @@ func main() {
}
app := gin.Default()
middleware.RegisterMiddleware(app)
{
app.Use(middleware.CORSMiddleware())
app.Use(middleware.BuiltinMiddleWare(connection.ConnectMySQL(), connection.ConnectRedis()))
app.Use(middleware.ThrottleMiddleware())
app.Use(auth.Middleware())
auth.Register(app)
manager.Register(app)
addition.Register(app)
conversation.Register(app)
}
app.POST("/anonymous", api.AnonymousAPI)
app.POST("/card", api.CardAPI)
app.GET("/chat", api.ChatAPI)
app.POST("/login", auth.LoginAPI)
app.POST("/state", auth.StateAPI)
app.GET("/package", auth.PackageAPI)
app.GET("/quota", auth.QuotaAPI)
app.POST("/buy", auth.BuyAPI)
app.GET("/subscription", auth.SubscriptionAPI)
app.POST("/subscribe", auth.SubscribeAPI)
app.GET("/conversation/list", conversation.ListAPI)
app.GET("/conversation/load", conversation.LoadAPI)
app.GET("/conversation/delete", conversation.DeleteAPI)
app.GET("/generation/create", generation.GenerateAPI)
app.GET("/generation/download/tar", generation.ProjectTarDownloadAPI)
app.GET("/generation/download/zip", generation.ProjectZipDownloadAPI)
}
if viper.GetBool("debug") {
gin.SetMode(gin.DebugMode)
} else {
gin.SetMode(gin.ReleaseMode)
}
if err := app.Run(":" + viper.GetString("server.port")); err != nil {
gin.SetMode(utils.Multi[string](viper.GetBool("debug"), gin.DebugMode, gin.ReleaseMode))
if err := app.Run(fmt.Sprintf(":%s", viper.GetString("server.port"))); err != nil {
panic(err)
}
}

34
manager/cache.go Normal file
View File

@ -0,0 +1,34 @@
package manager
import (
"chat/globals"
"chat/utils"
"fmt"
"github.com/gin-gonic/gin"
"time"
)
type CacheProps struct {
Message []globals.Message `json:"message" required:"true"`
Model string `json:"model" required:"true"`
Reversible bool `json:"reversible"`
}
type CacheData struct {
Keyword string `json:"keyword"`
Message string `json:"message"`
}
func ExtractCacheData(c *gin.Context, props *CacheProps) *CacheData {
hash := utils.Md5Encrypt(utils.Marshal(props))
data, err := utils.GetCacheFromContext(c).Get(c, fmt.Sprintf(":niodata:%s", hash)).Result()
if err == nil && data != "" {
return utils.UnmarshalForm[CacheData](data)
}
return nil
}
func SaveCacheData(c *gin.Context, props *CacheProps, data *CacheData) {
hash := utils.Md5Encrypt(utils.Marshal(props))
utils.GetCacheFromContext(c).Set(c, fmt.Sprintf(":niodata:%s", hash), utils.Marshal(data), time.Hour*12)
}

103
manager/chat.go Normal file
View File

@ -0,0 +1,103 @@
package manager
import (
"chat/adapter"
"chat/addition/web"
"chat/auth"
"chat/globals"
"chat/manager/conversation"
"chat/utils"
"fmt"
"github.com/gin-gonic/gin"
)
const defaultMessage = "Sorry, I don't understand. Please try again."
const defaultQuotaMessage = "You don't have enough quota to use this model. Please buy more quota to continue." +
"| GPT-4 | GPT-4-32k " +
"| 5 nio | 50 nio "
func GetErrorQuota(model string) float32 {
return utils.Multi[float32](globals.IsGPT4Model(model), -0xe, 0) // special value for error
}
func CollectQuota(c *gin.Context, user *auth.User, quota float32, reversible bool) {
db := utils.GetDBFromContext(c)
if !reversible && quota > 0 && user != nil {
user.UseQuota(db, quota)
}
}
func ChatHandler(conn *utils.WebSocket, user *auth.User, instance *conversation.Conversation) string {
defer func() {
if err := recover(); err != nil {
fmt.Println(fmt.Sprintf("caught panic from chat handler: %s (instance: %s, client: %s)",
err, instance.GetModel(), conn.GetCtx().ClientIP(),
))
}
}()
keyword, segment := web.UsingWebSegment(instance)
conn.Send(globals.ChatSegmentResponse{Keyword: keyword, End: false})
model := instance.GetModel()
db := conn.GetDB()
cache := conn.GetCache()
reversible := auth.CanEnableSubscription(db, cache, user)
if !auth.CanEnableModelWithSubscription(db, user, model, reversible) {
conn.Send(globals.ChatSegmentResponse{
Message: defaultQuotaMessage,
Quota: 0,
End: true,
})
return defaultQuotaMessage
}
if form := ExtractCacheData(conn.GetCtx(), &CacheProps{
Message: segment,
Model: model,
Reversible: reversible,
}); form != nil {
conn.Send(globals.ChatSegmentResponse{
Message: form.Message,
Quota: 0,
End: true,
})
return form.Message
}
buffer := utils.NewBuffer(model, segment)
if err := adapter.NewChatRequest(&adapter.ChatProps{
Model: model,
Message: segment,
Reversible: reversible && globals.IsGPT4Model(model),
}, func(data string) error {
return conn.SendJSON(globals.ChatSegmentResponse{
Message: buffer.Write(data),
Quota: buffer.GetQuota(),
End: false,
})
}); err != nil {
CollectQuota(conn.GetCtx(), user, buffer.GetQuota(), reversible)
conn.Send(globals.ChatSegmentResponse{
Message: err.Error(),
Quota: GetErrorQuota(model),
End: true,
})
return err.Error()
}
CollectQuota(conn.GetCtx(), user, buffer.GetQuota(), reversible)
conn.Send(globals.ChatSegmentResponse{End: true, Quota: buffer.GetQuota()})
SaveCacheData(conn.GetCtx(), &CacheProps{
Message: segment,
Model: model,
Reversible: reversible,
}, &CacheData{
Keyword: keyword,
Message: buffer.ReadWithDefault(defaultMessage),
})
return buffer.ReadWithDefault(defaultMessage)
}

65
manager/completions.go Normal file
View File

@ -0,0 +1,65 @@
package manager
import (
"chat/adapter"
"chat/addition/web"
"chat/auth"
"chat/globals"
"chat/utils"
"fmt"
"github.com/gin-gonic/gin"
)
func NativeChatHandler(c *gin.Context, user *auth.User, model string, message []globals.Message, enableWeb bool) (string, string, float32) {
defer func() {
if err := recover(); err != nil {
fmt.Println(fmt.Sprintf("caught panic from chat handler: %s (instance: %s, client: %s)",
err, model, c.ClientIP(),
))
}
}()
keyword, segment := web.UsingWebNativeSegment(enableWeb, message)
db := utils.GetDBFromContext(c)
cache := utils.GetCacheFromContext(c)
reversible := auth.CanEnableSubscription(db, cache, user)
if !auth.CanEnableModelWithSubscription(db, user, model, reversible) {
return keyword, defaultQuotaMessage, 0
}
if form := ExtractCacheData(c, &CacheProps{
Message: segment,
Model: model,
Reversible: reversible,
}); form != nil {
return form.Keyword, form.Message, 0
}
buffer := utils.NewBuffer(model, segment)
if err := adapter.NewChatRequest(&adapter.ChatProps{
Model: model,
Reversible: reversible && globals.IsGPT4Model(model),
Message: segment,
}, func(resp string) error {
buffer.Write(resp)
return nil
}); err != nil {
CollectQuota(c, user, buffer.GetQuota(), reversible)
return keyword, err.Error(), GetErrorQuota(model)
}
CollectQuota(c, user, buffer.GetQuota(), reversible)
SaveCacheData(c, &CacheProps{
Message: segment,
Model: model,
Reversible: reversible,
}, &CacheData{
Keyword: keyword,
Message: buffer.ReadWithDefault(defaultMessage),
})
return keyword, buffer.ReadWithDefault(defaultMessage), buffer.GetQuota()
}

View File

@ -1,7 +1,8 @@
package conversation
import (
"chat/types"
"chat/auth"
"chat/globals"
"chat/utils"
"database/sql"
"errors"
@ -9,12 +10,13 @@ import (
)
type Conversation struct {
UserID int64 `json:"user_id"`
Id int64 `json:"id"`
Name string `json:"name"`
Message []types.ChatGPTMessage `json:"message"`
Model string `json:"model"`
EnableWeb bool `json:"enable_web"`
Auth bool `json:"auth"`
UserID int64 `json:"user_id"`
Id int64 `json:"id"`
Name string `json:"name"`
Message []globals.Message `json:"message"`
Model string `json:"model"`
EnableWeb bool `json:"enable_web"`
}
type FormMessage struct {
@ -24,17 +26,48 @@ type FormMessage struct {
Model string `json:"model"`
}
func NewAnonymousConversation() *Conversation {
return &Conversation{
Auth: false,
UserID: -1,
Id: -1,
Name: "anonymous",
Message: []globals.Message{},
Model: globals.GPT3Turbo,
EnableWeb: false,
}
}
func NewConversation(db *sql.DB, id int64) *Conversation {
return &Conversation{
Auth: true,
UserID: id,
Id: GetConversationLengthByUserID(db, id) + 1,
Name: "new chat",
Message: []types.ChatGPTMessage{},
Model: types.GPT3Turbo,
Message: []globals.Message{},
Model: globals.GPT3Turbo,
EnableWeb: false,
}
}
func ExtractConversation(db *sql.DB, user *auth.User, id int64) *Conversation {
if user == nil {
return NewAnonymousConversation()
}
if id == -1 {
// create new conversation
return NewConversation(db, user.GetID(db))
}
// load conversation
if instance := LoadConversation(db, user.GetID(db), id); instance != nil {
return instance
} else {
return NewConversation(db, user.GetID(db))
}
}
func (c *Conversation) GetModel() string {
return c.Model
}
@ -72,7 +105,7 @@ func (c *Conversation) SetId(id int64) {
c.Id = id
}
func (c *Conversation) GetMessage() []types.ChatGPTMessage {
func (c *Conversation) GetMessage() []globals.Message {
return c.Message
}
@ -80,41 +113,41 @@ func (c *Conversation) GetMessageSize() int {
return len(c.Message)
}
func (c *Conversation) GetMessageSegment(length int) []types.ChatGPTMessage {
func (c *Conversation) GetMessageSegment(length int) []globals.Message {
if length > len(c.Message) {
return c.Message
}
return c.Message[len(c.Message)-length:]
}
func CopyMessage(message []types.ChatGPTMessage) []types.ChatGPTMessage {
return utils.UnmarshalJson[[]types.ChatGPTMessage](utils.ToJson(message)) // deep copy
func CopyMessage(message []globals.Message) []globals.Message {
return utils.UnmarshalJson[[]globals.Message](utils.ToJson(message)) // deep copy
}
func (c *Conversation) GetLastMessage() types.ChatGPTMessage {
func (c *Conversation) GetLastMessage() globals.Message {
return c.Message[len(c.Message)-1]
}
func (c *Conversation) AddMessage(message types.ChatGPTMessage) {
func (c *Conversation) AddMessage(message globals.Message) {
c.Message = append(c.Message, message)
}
func (c *Conversation) AddMessageFromUser(message string) {
c.AddMessage(types.ChatGPTMessage{
c.AddMessage(globals.Message{
Role: "user",
Content: message,
})
}
func (c *Conversation) AddMessageFromAssistant(message string) {
c.AddMessage(types.ChatGPTMessage{
c.AddMessage(globals.Message{
Role: "assistant",
Content: message,
})
}
func (c *Conversation) AddMessageFromSystem(message string) {
c.AddMessage(types.ChatGPTMessage{
c.AddMessage(globals.Message{
Role: "system",
Content: message,
})
@ -132,7 +165,7 @@ func GetMessage(data []byte) (string, error) {
return form.Message, nil
}
func (c *Conversation) AddMessageFromUserForm(data []byte) (string, error) {
func (c *Conversation) AddMessageFromByte(data []byte) (string, error) {
form, err := utils.Unmarshal[FormMessage](data)
if err != nil {
return "", err
@ -146,9 +179,32 @@ func (c *Conversation) AddMessageFromUserForm(data []byte) (string, error) {
return form.Message, nil
}
func (c *Conversation) HandleMessage(db *sql.DB, data []byte) bool {
func (c *Conversation) AddMessageFromForm(form *FormMessage) error {
if len(form.Message) == 0 {
return errors.New("message is empty")
}
c.AddMessageFromUser(form.Message)
c.SetModel(form.Model)
c.SetEnableWeb(form.Web)
return nil
}
func (c *Conversation) HandleMessage(db *sql.DB, form *FormMessage) bool {
head := len(c.Message) == 0
msg, err := c.AddMessageFromUserForm(data)
if err := c.AddMessageFromForm(form); err != nil {
return false
}
if head {
c.SetName(db, form.Message)
}
c.SaveConversation(db)
return true
}
func (c *Conversation) HandleMessageFromByte(db *sql.DB, data []byte) bool {
head := len(c.Message) == 0
msg, err := c.AddMessageFromByte(data)
if err != nil {
return false
}

View File

@ -0,0 +1,12 @@
package conversation
import "github.com/gin-gonic/gin"
func Register(app *gin.Engine) {
router := app.Group("/conversation")
{
router.GET("/list", ListAPI)
router.GET("/load", LoadAPI)
router.GET("/delete", DeleteAPI)
}
}

View File

@ -1,13 +1,18 @@
package conversation
import (
"chat/types"
"chat/globals"
"chat/utils"
"database/sql"
"log"
)
func (c *Conversation) SaveConversation(db *sql.DB) bool {
if c.UserID == -1 {
// anonymous request
return true
}
data := utils.ToJson(c.GetMessage())
query := `INSERT INTO conversation (user_id, conversation_id, conversation_name, data) VALUES (?, ?, ?, ?) ON DUPLICATE KEY UPDATE conversation_name = VALUES(conversation_name), data = VALUES(data)`
@ -49,7 +54,7 @@ func LoadConversation(db *sql.DB, userId int64, conversationId int64) *Conversat
return nil
}
conversation.Message, err = utils.Unmarshal[[]types.ChatGPTMessage]([]byte(data))
conversation.Message, err = utils.Unmarshal[[]globals.Message]([]byte(data))
if err != nil {
return nil
}

49
manager/image.go Normal file
View File

@ -0,0 +1,49 @@
package manager
import (
"chat/adapter/chatgpt"
"chat/auth"
"chat/utils"
"database/sql"
"fmt"
"github.com/gin-gonic/gin"
"strings"
"time"
)
func GetImageLimitFormat(db *sql.DB, user *auth.User) string {
return fmt.Sprintf(":imagelimit:%s:%d", time.Now().Format("2006-01-02"), user.GetID(db))
}
func GenerateImage(c *gin.Context, user *auth.User, prompt string) (string, error) {
// free plan: 5 images per day
// pro plan: 50 images per day
db := utils.GetDBFromContext(c)
cache := utils.GetCacheFromContext(c)
key := GetImageLimitFormat(db, user)
usage := auth.GetDalleUsageLimit(db, user)
prompt = strings.TrimSpace(prompt)
if len(prompt) == 0 {
return "", fmt.Errorf("please provide description for the image (e.g. /image an apple)")
}
if utils.IncrWithLimit(cache, key, 1, int64(usage), 60*60*24) || auth.ReduceDalle(db, user) {
instance := chatgpt.NewChatInstanceFromModel(&chatgpt.InstanceProps{
Model: "dalle",
})
response, err := instance.CreateImage(chatgpt.ImageProps{
Prompt: prompt,
})
if err != nil {
return "", err
} else {
return utils.GetImageMarkdown(response), nil
}
} else {
return "", fmt.Errorf("you have reached your limit of %d free images per day, please buy more quota or wait until tomorrow", usage)
}
}

100
manager/manager.go Normal file
View File

@ -0,0 +1,100 @@
package manager
import (
"chat/auth"
"chat/globals"
"chat/manager/conversation"
"chat/utils"
"fmt"
"github.com/gin-gonic/gin"
"strconv"
"strings"
)
type WebsocketAuthForm struct {
Token string `json:"token" binding:"required"`
Id int64 `json:"id" binding:"required"`
}
func EventHandler(conn *utils.WebSocket, instance *conversation.Conversation, user *auth.User) string {
if strings.HasPrefix(instance.GetLatestMessage(), "/image") {
if user == nil {
conn.Send(globals.ChatSegmentResponse{
Message: "You need to login to use this feature.",
End: true,
})
return "You need to login to use this feature."
}
prompt := strings.TrimSpace(strings.TrimPrefix(instance.GetLatestMessage(), "/image"))
if response, err := GenerateImage(conn.GetCtx(), user, prompt); err != nil {
conn.Send(globals.ChatSegmentResponse{
Message: err.Error(),
End: true,
})
return err.Error()
} else {
conn.Send(globals.ChatSegmentResponse{
Quota: 1.,
Message: response,
End: true,
})
return response
}
} else {
return ChatHandler(conn, user, instance)
}
}
func ChatAPI(c *gin.Context) {
var conn *utils.WebSocket
if conn = utils.NewWebsocket(c); conn == nil {
return
}
defer conn.DeferClose()
db := utils.GetDBFromContext(c)
var form *WebsocketAuthForm
if form = utils.ReadForm[WebsocketAuthForm](conn); form == nil {
return
}
user := auth.ParseToken(c, form.Token)
authenticated := user != nil
id := auth.GetId(db, user)
instance := conversation.ExtractConversation(db, user, id)
hash := fmt.Sprintf(":chatthread:%s", utils.Md5Encrypt(utils.Multi(
authenticated,
strconv.FormatInt(id, 10),
c.ClientIP(),
)))
for {
var form *conversation.FormMessage
if form := utils.ReadForm[conversation.FormMessage](conn); form == nil {
return
}
if instance.HandleMessage(db, form) {
if !conn.IncrRateWithLimit(
hash,
utils.Multi[int64](authenticated, globals.ChatMaxThread, globals.AnonymousMaxThread),
60,
) {
conn.Send(globals.ChatSegmentResponse{
Message: fmt.Sprintf("You have reached the maximum number of threads (%d) the same time. Please wait for a while.", globals.ChatMaxThread),
End: true,
})
return
}
response := EventHandler(conn, instance, user)
conn.DecrRate(hash)
instance.SaveResponse(db, response)
}
}
}

7
manager/router.go Normal file
View File

@ -0,0 +1,7 @@
package manager
import "github.com/gin-gonic/gin"
func Register(app *gin.Engine) {
app.GET("/chat", ChatAPI)
}

View File

@ -1,15 +1,16 @@
package auth
package middleware
import (
"chat/auth"
"github.com/gin-gonic/gin"
"strings"
)
func Middleware() gin.HandlerFunc {
func AuthMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
token := strings.TrimSpace(c.GetHeader("Authorization"))
if token != "" {
if user := ParseToken(c, token); user != nil {
if user := auth.ParseToken(c, token); user != nil {
c.Set("token", token)
c.Set("auth", true)
c.Set("user", user.Username)
@ -24,16 +25,3 @@ func Middleware() gin.HandlerFunc {
c.Next()
}
}
func GetToken(c *gin.Context) string {
return c.GetString("token")
}
func GetUser(c *gin.Context) *User {
if c.GetBool("auth") {
return &User{
Username: c.GetString("user"),
}
}
return nil
}

View File

@ -1,24 +1,16 @@
package middleware
import (
"chat/globals"
"chat/utils"
"github.com/gin-gonic/gin"
"net/http"
)
var AllowedOrigins = []string{
"https://fystart.cn",
"https://www.fystart.cn",
"https://nio.fystart.cn",
"https://chatnio.net",
"https://www.chatnio.net",
"http://localhost:5173",
}
func CORSMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
origin := c.Request.Header.Get("Origin")
if utils.Contains(origin, AllowedOrigins) {
if utils.Contains(origin, globals.AllowedOrigins) {
c.Writer.Header().Set("Access-Control-Allow-Origin", origin)
c.Writer.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
c.Writer.Header().Set("Access-Control-Allow-Headers", "Origin, Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization")

13
middleware/middleware.go Normal file
View File

@ -0,0 +1,13 @@
package middleware
import (
"chat/connection"
"github.com/gin-gonic/gin"
)
func RegisterMiddleware(app *gin.Engine) {
app.Use(CORSMiddleware())
app.Use(BuiltinMiddleWare(connection.ConnectMySQL(), connection.ConnectRedis()))
app.Use(ThrottleMiddleware())
app.Use(AuthMiddleware())
}

View File

@ -1,55 +0,0 @@
package types
type ChatGPTMessage struct {
Role string `json:"role"`
Content string `json:"content"`
}
type ChatGPTRequest struct {
Model string `json:"model"`
Messages []ChatGPTMessage `json:"messages"`
MaxToken int `json:"max_tokens"`
Stream bool `json:"stream"`
}
type ChatGPTRequestWithInfinity struct {
Model string `json:"model"`
Messages []ChatGPTMessage `json:"messages"`
Stream bool `json:"stream"`
}
type ChatGPTImageRequest struct {
Prompt string `json:"prompt"`
Size string `json:"size"`
N int `json:"n"`
}
type ChatGPTStreamResponse struct {
ID string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
Model string `json:"model"`
Data struct {
Choices []struct {
Delta struct {
Content string `json:"content"`
}
Index int `json:"index"`
} `json:"choices"`
} `json:"data"`
}
type ChatSegmentResponse struct {
Quota float32 `json:"quota"`
Keyword string `json:"keyword"`
Message string `json:"message"`
End bool `json:"end"`
}
type GenerationSegmentResponse struct {
Quota float32 `json:"quota"`
Message string `json:"message"`
Hash string `json:"hash"`
End bool `json:"end"`
Error string `json:"error"`
}

View File

@ -1,8 +1,8 @@
package api
package utils
import (
"chat/auth"
"chat/types"
"chat/globals"
)
type Buffer struct {
@ -13,7 +13,7 @@ type Buffer struct {
Times int `json:"times"`
}
func NewBuffer(model string, history []types.ChatGPTMessage) *Buffer {
func NewBuffer(model string, history []globals.Message) *Buffer {
return &Buffer{
Data: "",
Cursor: 0,

View File

@ -2,6 +2,7 @@ package utils
import (
"encoding/json"
"fmt"
"math/rand"
"strconv"
"time"
@ -37,6 +38,14 @@ func Unmarshal[T interface{}](data []byte) (form T, err error) {
return form, err
}
func UnmarshalForm[T interface{}](data string) *T {
form, err := Unmarshal[T]([]byte(data))
if err != nil {
return nil
}
return &form
}
func Marshal[T interface{}](data T) string {
res, err := json.Marshal(data)
if err != nil {
@ -65,3 +74,7 @@ func ToInt(value string) int {
func ConvertSqlTime(t time.Time) string {
return t.Format("2006-01-02 15:04:05")
}
func GetImageMarkdown(url string) string {
return fmt.Sprintf("![image](%s)", url)
}

View File

@ -3,12 +3,13 @@ package utils
import (
"database/sql"
"github.com/gin-gonic/gin"
"github.com/go-redis/redis/v8"
)
func GetDBFromContext(c *gin.Context) *sql.DB {
return c.MustGet("db").(*sql.DB)
}
func GetCacheFromContext(c *gin.Context) *sql.DB {
return c.MustGet("cache").(*sql.DB)
func GetCacheFromContext(c *gin.Context) *redis.Client {
return c.MustGet("cache").(*redis.Client)
}

12
utils/key.go Normal file
View File

@ -0,0 +1,12 @@
package utils
import (
"math/rand"
"strings"
)
func GetRandomKey(apikey string) string {
arr := strings.Split(apikey, "|")
idx := rand.Intn(len(arr))
return arr[idx]
}

View File

@ -2,10 +2,12 @@ package utils
import (
"bytes"
"crypto/tls"
"encoding/json"
"io"
"net/http"
"net/url"
"strings"
)
func Http(uri string, method string, ptr interface{}, headers map[string]string, body io.Reader) (err error) {
@ -100,3 +102,43 @@ func PostForm(uri string, body map[string]interface{}) (data map[string]interfac
return data, nil
}
func EventSource(method string, uri string, headers map[string]string, body interface{}, callback func(string) error) error {
http.DefaultTransport.(*http.Transport).TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
client := &http.Client{}
req, err := http.NewRequest(method, uri, ConvertBody(body))
if err != nil {
return nil
}
for key, value := range headers {
req.Header.Set(key, value)
}
res, err := client.Do(req)
if err != nil {
return err
}
defer res.Body.Close()
for {
buf := make([]byte, 20480)
n, err := res.Body.Read(buf)
if err == io.EOF {
return nil
} else if err != nil {
return err
}
data := string(buf[:n])
for _, item := range strings.Split(data, "\n") {
segment := strings.TrimSpace(item)
if len(segment) > 0 {
if err := callback(segment); err != nil {
return err
}
}
}
}
}

View File

@ -1,7 +1,7 @@
package utils
import (
"chat/types"
"chat/globals"
"fmt"
"github.com/pkoukk/tiktoken-go"
"strings"
@ -13,42 +13,42 @@ import (
func GetWeightByModel(model string) int {
switch model {
case types.Claude2,
types.Claude2100k:
case globals.Claude2,
globals.Claude2100k:
return 2
case types.GPT432k,
types.GPT432k0613,
types.GPT432k0314:
case globals.GPT432k,
globals.GPT432k0613,
globals.GPT432k0314:
return 3 * 10
case types.GPT3Turbo,
types.GPT3Turbo0613,
case globals.GPT3Turbo,
globals.GPT3Turbo0613,
types.GPT3Turbo16k,
types.GPT3Turbo16k0613,
globals.GPT3Turbo16k,
globals.GPT3Turbo16k0613,
types.GPT4,
types.GPT40314,
types.GPT40613:
globals.GPT4,
globals.GPT40314,
globals.GPT40613:
return 3
case types.GPT3Turbo0301, types.GPT3Turbo16k0301:
case globals.GPT3Turbo0301, globals.GPT3Turbo16k0301:
return 4 // every message follows <|start|>{role/name}\n{content}<|end|>\n
default:
if strings.Contains(model, types.GPT3Turbo) {
if strings.Contains(model, globals.GPT3Turbo) {
// warning: gpt-3.5-turbo may update over time. Returning num tokens assuming gpt-3.5-turbo-0613.
return GetWeightByModel(types.GPT3Turbo0613)
} else if strings.Contains(model, types.GPT4) {
return GetWeightByModel(globals.GPT3Turbo0613)
} else if strings.Contains(model, globals.GPT4) {
// warning: gpt-4 may update over time. Returning num tokens assuming gpt-4-0613.
return GetWeightByModel(types.GPT40613)
} else if strings.Contains(model, types.Claude2) {
return GetWeightByModel(globals.GPT40613)
} else if strings.Contains(model, globals.Claude2) {
// warning: claude-2 may update over time. Returning num tokens assuming claude-2-100k.
return GetWeightByModel(types.Claude2100k)
return GetWeightByModel(globals.Claude2100k)
} else {
// not implemented: See https://github.com/openai/openai-python/blob/main/chatml.md for information on how messages are converted to tokens
panic(fmt.Errorf("not implemented for model %s", model))
}
}
}
func NumTokensFromMessages(messages []types.ChatGPTMessage, model string) (tokens int) {
func NumTokensFromMessages(messages []globals.Message, model string) (tokens int) {
weight := GetWeightByModel(model)
tkm, err := tiktoken.EncodingForModel(model)
if err != nil {
@ -68,6 +68,6 @@ func NumTokensFromMessages(messages []types.ChatGPTMessage, model string) (token
return tokens
}
func CountTokenPrice(messages []types.ChatGPTMessage, model string) int {
func CountTokenPrice(messages []globals.Message, model string) int {
return NumTokensFromMessages(messages, model)
}

132
utils/websocket.go Normal file
View File

@ -0,0 +1,132 @@
package utils
import (
"chat/globals"
"database/sql"
"github.com/gin-gonic/gin"
"github.com/go-redis/redis/v8"
"github.com/gorilla/websocket"
"io"
"net/http"
)
type WebSocket struct {
Ctx *gin.Context
Conn *websocket.Conn
}
func CheckUpgrader(c *gin.Context) *websocket.Upgrader {
return &websocket.Upgrader{
CheckOrigin: func(r *http.Request) bool {
origin := c.Request.Header.Get("Origin")
if Contains(origin, globals.AllowedOrigins) {
return true
}
return false
},
}
}
func NewWebsocket(c *gin.Context) *WebSocket {
upgrader := CheckUpgrader(c)
if conn, err := upgrader.Upgrade(c.Writer, c.Request, nil); err != nil {
c.JSON(http.StatusOK, gin.H{
"status": false,
"message": "",
"reason": err.Error(),
})
return nil
} else {
return &WebSocket{
Ctx: c,
Conn: conn,
}
}
}
func (w *WebSocket) Read() (int, []byte, error) {
return w.Conn.ReadMessage()
}
func (w *WebSocket) Write(messageType int, data []byte) error {
return w.Conn.WriteMessage(messageType, data)
}
func (w *WebSocket) Close() error {
return w.Conn.Close()
}
func (w *WebSocket) DeferClose() {
if err := w.Close(); err != nil {
return
}
}
func (w *WebSocket) NextWriter(messageType int) (io.WriteCloser, error) {
return w.Conn.NextWriter(messageType)
}
func (w *WebSocket) ReadJSON(v interface{}) error {
return w.Conn.ReadJSON(v)
}
func (w *WebSocket) SendJSON(v interface{}) error {
return w.Conn.WriteJSON(v)
}
func (w *WebSocket) Send(v interface{}) bool {
return w.SendJSON(v) == nil
}
func (w *WebSocket) Receive(v interface{}) bool {
return w.ReadJSON(v) == nil
}
func (w *WebSocket) SendText(message string) bool {
return w.Write(websocket.TextMessage, []byte(message)) == nil
}
func (w *WebSocket) DecrRate(key string) bool {
cache := w.GetCache()
return DecrInt(cache, key, 1)
}
func (w *WebSocket) IncrRate(key string) bool {
cache := w.GetCache()
_, err := Incr(cache, key, 1)
return err == nil
}
func (w *WebSocket) IncrRateWithLimit(key string, limit int64, expiration int64) bool {
cache := w.GetCache()
return IncrWithLimit(cache, key, 1, limit, expiration)
}
func (w *WebSocket) GetCtx() *gin.Context {
return w.Ctx
}
func (w *WebSocket) GetDB() *sql.DB {
return GetDBFromContext(w.Ctx)
}
func (w *WebSocket) GetCache() *redis.Client {
return GetCacheFromContext(w.Ctx)
}
func ReadForm[T comparable](w *WebSocket) *T {
// golang cannot use generic type in class-like struct
// except ping
_, message, err := w.Read()
if err != nil {
return nil
} else if string(message) == "{\"type\":\"ping\"}" {
return ReadForm[T](w)
}
form, err := Unmarshal[T](message)
if err != nil {
return nil
}
return &form
}