update generation quota counter

This commit is contained in:
Zhang Minghan 2023-09-24 09:38:59 +08:00
parent 3d165bf42d
commit 6577e54a2e
3 changed files with 17 additions and 9 deletions

View File

@ -95,23 +95,32 @@ func GenerateAPI(c *gin.Context) {
return return
} }
hash, err := CreateGenerationWithCache(form.Model, form.Prompt, useReverse, func(data string) { var instance *api.Buffer
hash, err := CreateGenerationWithCache(form.Model, form.Prompt, useReverse, func(buffer *api.Buffer, data string) {
instance = buffer
api.SendSegmentMessage(conn, types.GenerationSegmentResponse{ api.SendSegmentMessage(conn, types.GenerationSegmentResponse{
End: false, End: false,
Message: data, Message: data,
Quota: buffer.GetQuota(),
}) })
}) })
if instance != nil && !useReverse {
user.UseQuota(db, instance.GetQuota())
}
if err != nil { if err != nil {
api.SendSegmentMessage(conn, types.GenerationSegmentResponse{ api.SendSegmentMessage(conn, types.GenerationSegmentResponse{
End: true, End: true,
Error: err.Error(), Error: err.Error(),
Quota: instance.GetQuota(),
}) })
return return
} }
api.SendSegmentMessage(conn, types.GenerationSegmentResponse{ api.SendSegmentMessage(conn, types.GenerationSegmentResponse{
End: true, End: true,
Hash: hash, Hash: hash,
Quota: instance.GetQuota(),
}) })
} }

View File

@ -1,16 +1,15 @@
package generation package generation
import ( import (
"chat/api"
"chat/utils" "chat/utils"
"fmt" "fmt"
) )
func CreateGenerationWithCache(model string, prompt string, enableReverse bool, hook func(data string)) (string, error) { func CreateGenerationWithCache(model string, prompt string, enableReverse bool, hook func(buffer *api.Buffer, data string)) (string, error) {
hash, path := GetFolderByHash(model, prompt) hash, path := GetFolderByHash(model, prompt)
if !utils.Exists(path) { if !utils.Exists(path) {
if err := CreateGeneration(model, prompt, path, enableReverse, func(data string) { if err := CreateGeneration(model, prompt, path, enableReverse, hook); err != nil {
hook(data)
}); err != nil {
fmt.Println(fmt.Sprintf("[Project] error during generation %s (model %s): %s", prompt, model, err.Error())) fmt.Println(fmt.Sprintf("[Project] error during generation %s (model %s): %s", prompt, model, err.Error()))
return "", fmt.Errorf("error during generate project: %s", err.Error()) return "", fmt.Errorf("error during generate project: %s", err.Error())
} }

View File

@ -11,7 +11,7 @@ type ProjectResult struct {
Result map[string]interface{} `json:"result"` Result map[string]interface{} `json:"result"`
} }
func CreateGeneration(model string, prompt string, path string, enableReverse bool, hook func(data string)) error { func CreateGeneration(model string, prompt string, path string, enableReverse bool, hook func(buffer *api.Buffer, data string)) error {
message := GenerateMessage(prompt) message := GenerateMessage(prompt)
buffer := api.NewBuffer(model, message) buffer := api.NewBuffer(model, message)
api.StreamRequest(model, enableReverse, []types.ChatGPTMessage{ api.StreamRequest(model, enableReverse, []types.ChatGPTMessage{
@ -22,8 +22,8 @@ func CreateGeneration(model string, prompt string, path string, enableReverse bo
{Role: "assistant", Content: "{\n \"result\": {\n \"main.go\": \"package main\\n\\nimport (\\n\\t\"log\\\"\\n\\n\\t\"github.com/gofiber/fiber/v2\\\"\\n\\t\"github.com/gofiber/websocket/v2\\\"\\n)\\n\\nfunc main() {\\n\\tapp := fiber.New()\\n\\n\\tapp.Get(\\\"/\\\", func(c *fiber.Ctx) error {\\n\\t\\treturn c.SendString(\\\"Hello, World!\\\")\\n\\t})\\n\\n\\tapp.Get(\\\"/ws\\\", websocket.New(func(c *websocket.Conn) {\\n\\t\\tfor {\\n\\t\\t\\tmt, message, err := c.ReadMessage()\\n\\t\\t\\tif err != nil {\\n\\t\\t\\t\\tlog.Println(\\\"read error:\\\", err)\\n\\t\\t\\t\\tbreak\\n\\t\\t\\t}\\n\\t\\t\\tlog.Printf(\\\"received: %s\\\", message)\\n\\t\\t\\terr = c.WriteMessage(mt, message)\\n\\t\\t\\tif err != nil {\\n\\t\\t\\t\\tlog.Println(\\\"write error:\\\", err)\\n\\t\\t\\t\\tbreak\\n\\t\\t\\t}\\n\\t\\t}\\n\\t}))\\n\\n\\tlog.Fatal(app.Listen(\\\":3000\\\"))\\n}\",\n \"go.mod\": \"module fiber-websocket\\n\\ngo 1.16\\n\\nrequire (\\n\\tgithub.com/gofiber/fiber/v2 v2.12.1\\n\\tgithub.com/gofiber/websocket/v2 v2.10.2\\n)\",\n \"README.md\": \"# Golang Fiber WebSocket项目\\n\\n这个项目是一个使用 Golang 和 Fiber 框架构建的 WebSocket 服务器示例。\\n\\n### 运行应用程序:\\n```shell\\ngo run main.go\\n```\\n\\n应用程序将在本地服务器默认是在http://localhost:3000上运行。当你在浏览器中访问`http://localhost:3000`时,将看到显示\"Hello, World!\"的页面。你还可以访问`http://localhost:3000/ws`来测试 WebSocket 连接。\n\n注意在运行应用程序之前请确保已经安装了Go语言开发环境。\"\n }\n}"}, {Role: "assistant", Content: "{\n \"result\": {\n \"main.go\": \"package main\\n\\nimport (\\n\\t\"log\\\"\\n\\n\\t\"github.com/gofiber/fiber/v2\\\"\\n\\t\"github.com/gofiber/websocket/v2\\\"\\n)\\n\\nfunc main() {\\n\\tapp := fiber.New()\\n\\n\\tapp.Get(\\\"/\\\", func(c *fiber.Ctx) error {\\n\\t\\treturn c.SendString(\\\"Hello, World!\\\")\\n\\t})\\n\\n\\tapp.Get(\\\"/ws\\\", websocket.New(func(c *websocket.Conn) {\\n\\t\\tfor {\\n\\t\\t\\tmt, message, err := c.ReadMessage()\\n\\t\\t\\tif err != nil {\\n\\t\\t\\t\\tlog.Println(\\\"read error:\\\", err)\\n\\t\\t\\t\\tbreak\\n\\t\\t\\t}\\n\\t\\t\\tlog.Printf(\\\"received: %s\\\", message)\\n\\t\\t\\terr = c.WriteMessage(mt, message)\\n\\t\\t\\tif err != nil {\\n\\t\\t\\t\\tlog.Println(\\\"write error:\\\", err)\\n\\t\\t\\t\\tbreak\\n\\t\\t\\t}\\n\\t\\t}\\n\\t}))\\n\\n\\tlog.Fatal(app.Listen(\\\":3000\\\"))\\n}\",\n \"go.mod\": \"module fiber-websocket\\n\\ngo 1.16\\n\\nrequire (\\n\\tgithub.com/gofiber/fiber/v2 v2.12.1\\n\\tgithub.com/gofiber/websocket/v2 v2.10.2\\n)\",\n \"README.md\": \"# Golang Fiber WebSocket项目\\n\\n这个项目是一个使用 Golang 和 Fiber 框架构建的 WebSocket 服务器示例。\\n\\n### 运行应用程序:\\n```shell\\ngo run main.go\\n```\\n\\n应用程序将在本地服务器默认是在http://localhost:3000上运行。当你在浏览器中访问`http://localhost:3000`时,将看到显示\"Hello, World!\"的页面。你还可以访问`http://localhost:3000/ws`来测试 WebSocket 连接。\n\n注意在运行应用程序之前请确保已经安装了Go语言开发环境。\"\n }\n}"},
{Role: "user", Content: prompt}, {Role: "user", Content: prompt},
}, -1, func(data string) { }, -1, func(data string) {
hook(data)
buffer.Write(data) buffer.Write(data)
hook(buffer, data)
}) })
resp, err := utils.Unmarshal[ProjectResult](buffer.ReadBytes()) resp, err := utils.Unmarshal[ProjectResult](buffer.ReadBytes())