update generation quota counter

This commit is contained in:
Zhang Minghan 2023-09-24 09:38:59 +08:00
parent 3d165bf42d
commit 6577e54a2e
3 changed files with 17 additions and 9 deletions

View File

@ -95,23 +95,32 @@ func GenerateAPI(c *gin.Context) {
return
}
hash, err := CreateGenerationWithCache(form.Model, form.Prompt, useReverse, func(data string) {
var instance *api.Buffer
hash, err := CreateGenerationWithCache(form.Model, form.Prompt, useReverse, func(buffer *api.Buffer, data string) {
instance = buffer
api.SendSegmentMessage(conn, types.GenerationSegmentResponse{
End: false,
Message: data,
Quota: buffer.GetQuota(),
})
})
if instance != nil && !useReverse {
user.UseQuota(db, instance.GetQuota())
}
if err != nil {
api.SendSegmentMessage(conn, types.GenerationSegmentResponse{
End: true,
Error: err.Error(),
Quota: instance.GetQuota(),
})
return
}
api.SendSegmentMessage(conn, types.GenerationSegmentResponse{
End: true,
Hash: hash,
End: true,
Hash: hash,
Quota: instance.GetQuota(),
})
}

View File

@ -1,16 +1,15 @@
package generation
import (
"chat/api"
"chat/utils"
"fmt"
)
func CreateGenerationWithCache(model string, prompt string, enableReverse bool, hook func(data string)) (string, error) {
func CreateGenerationWithCache(model string, prompt string, enableReverse bool, hook func(buffer *api.Buffer, data string)) (string, error) {
hash, path := GetFolderByHash(model, prompt)
if !utils.Exists(path) {
if err := CreateGeneration(model, prompt, path, enableReverse, func(data string) {
hook(data)
}); err != nil {
if err := CreateGeneration(model, prompt, path, enableReverse, hook); err != nil {
fmt.Println(fmt.Sprintf("[Project] error during generation %s (model %s): %s", prompt, model, err.Error()))
return "", fmt.Errorf("error during generate project: %s", err.Error())
}

View File

@ -11,7 +11,7 @@ type ProjectResult struct {
Result map[string]interface{} `json:"result"`
}
func CreateGeneration(model string, prompt string, path string, enableReverse bool, hook func(data string)) error {
func CreateGeneration(model string, prompt string, path string, enableReverse bool, hook func(buffer *api.Buffer, data string)) error {
message := GenerateMessage(prompt)
buffer := api.NewBuffer(model, message)
api.StreamRequest(model, enableReverse, []types.ChatGPTMessage{
@ -22,8 +22,8 @@ func CreateGeneration(model string, prompt string, path string, enableReverse bo
{Role: "assistant", Content: "{\n \"result\": {\n \"main.go\": \"package main\\n\\nimport (\\n\\t\"log\\\"\\n\\n\\t\"github.com/gofiber/fiber/v2\\\"\\n\\t\"github.com/gofiber/websocket/v2\\\"\\n)\\n\\nfunc main() {\\n\\tapp := fiber.New()\\n\\n\\tapp.Get(\\\"/\\\", func(c *fiber.Ctx) error {\\n\\t\\treturn c.SendString(\\\"Hello, World!\\\")\\n\\t})\\n\\n\\tapp.Get(\\\"/ws\\\", websocket.New(func(c *websocket.Conn) {\\n\\t\\tfor {\\n\\t\\t\\tmt, message, err := c.ReadMessage()\\n\\t\\t\\tif err != nil {\\n\\t\\t\\t\\tlog.Println(\\\"read error:\\\", err)\\n\\t\\t\\t\\tbreak\\n\\t\\t\\t}\\n\\t\\t\\tlog.Printf(\\\"received: %s\\\", message)\\n\\t\\t\\terr = c.WriteMessage(mt, message)\\n\\t\\t\\tif err != nil {\\n\\t\\t\\t\\tlog.Println(\\\"write error:\\\", err)\\n\\t\\t\\t\\tbreak\\n\\t\\t\\t}\\n\\t\\t}\\n\\t}))\\n\\n\\tlog.Fatal(app.Listen(\\\":3000\\\"))\\n}\",\n \"go.mod\": \"module fiber-websocket\\n\\ngo 1.16\\n\\nrequire (\\n\\tgithub.com/gofiber/fiber/v2 v2.12.1\\n\\tgithub.com/gofiber/websocket/v2 v2.10.2\\n)\",\n \"README.md\": \"# Golang Fiber WebSocket项目\\n\\n这个项目是一个使用 Golang 和 Fiber 框架构建的 WebSocket 服务器示例。\\n\\n### 运行应用程序:\\n```shell\\ngo run main.go\\n```\\n\\n应用程序将在本地服务器默认是在http://localhost:3000上运行。当你在浏览器中访问`http://localhost:3000`时,将看到显示\"Hello, World!\"的页面。你还可以访问`http://localhost:3000/ws`来测试 WebSocket 连接。\n\n注意在运行应用程序之前请确保已经安装了Go语言开发环境。\"\n }\n}"},
{Role: "user", Content: prompt},
}, -1, func(data string) {
hook(data)
buffer.Write(data)
hook(buffer, data)
})
resp, err := utils.Unmarshal[ProjectResult](buffer.ReadBytes())