diff --git a/generation/api.go b/generation/api.go index 9e8fcd1..ef2dbbb 100644 --- a/generation/api.go +++ b/generation/api.go @@ -95,23 +95,32 @@ func GenerateAPI(c *gin.Context) { return } - hash, err := CreateGenerationWithCache(form.Model, form.Prompt, useReverse, func(data string) { + var instance *api.Buffer + hash, err := CreateGenerationWithCache(form.Model, form.Prompt, useReverse, func(buffer *api.Buffer, data string) { + instance = buffer api.SendSegmentMessage(conn, types.GenerationSegmentResponse{ End: false, Message: data, + Quota: buffer.GetQuota(), }) }) + if instance != nil && !useReverse { + user.UseQuota(db, instance.GetQuota()) + } + if err != nil { api.SendSegmentMessage(conn, types.GenerationSegmentResponse{ End: true, Error: err.Error(), + Quota: instance.GetQuota(), }) return } api.SendSegmentMessage(conn, types.GenerationSegmentResponse{ - End: true, - Hash: hash, + End: true, + Hash: hash, + Quota: instance.GetQuota(), }) } diff --git a/generation/generate.go b/generation/generate.go index 621f133..ce07aae 100644 --- a/generation/generate.go +++ b/generation/generate.go @@ -1,16 +1,15 @@ package generation import ( + "chat/api" "chat/utils" "fmt" ) -func CreateGenerationWithCache(model string, prompt string, enableReverse bool, hook func(data string)) (string, error) { +func CreateGenerationWithCache(model string, prompt string, enableReverse bool, hook func(buffer *api.Buffer, data string)) (string, error) { hash, path := GetFolderByHash(model, prompt) if !utils.Exists(path) { - if err := CreateGeneration(model, prompt, path, enableReverse, func(data string) { - hook(data) - }); err != nil { + if err := CreateGeneration(model, prompt, path, enableReverse, hook); err != nil { fmt.Println(fmt.Sprintf("[Project] error during generation %s (model %s): %s", prompt, model, err.Error())) return "", fmt.Errorf("error during generate project: %s", err.Error()) } diff --git a/generation/prompt.go b/generation/prompt.go index 31713cb..47e683a 100644 --- a/generation/prompt.go +++ b/generation/prompt.go @@ -11,7 +11,7 @@ type ProjectResult struct { Result map[string]interface{} `json:"result"` } -func CreateGeneration(model string, prompt string, path string, enableReverse bool, hook func(data string)) error { +func CreateGeneration(model string, prompt string, path string, enableReverse bool, hook func(buffer *api.Buffer, data string)) error { message := GenerateMessage(prompt) buffer := api.NewBuffer(model, message) api.StreamRequest(model, enableReverse, []types.ChatGPTMessage{ @@ -22,8 +22,8 @@ func CreateGeneration(model string, prompt string, path string, enableReverse bo {Role: "assistant", Content: "{\n \"result\": {\n \"main.go\": \"package main\\n\\nimport (\\n\\t\"log\\\"\\n\\n\\t\"github.com/gofiber/fiber/v2\\\"\\n\\t\"github.com/gofiber/websocket/v2\\\"\\n)\\n\\nfunc main() {\\n\\tapp := fiber.New()\\n\\n\\tapp.Get(\\\"/\\\", func(c *fiber.Ctx) error {\\n\\t\\treturn c.SendString(\\\"Hello, World!\\\")\\n\\t})\\n\\n\\tapp.Get(\\\"/ws\\\", websocket.New(func(c *websocket.Conn) {\\n\\t\\tfor {\\n\\t\\t\\tmt, message, err := c.ReadMessage()\\n\\t\\t\\tif err != nil {\\n\\t\\t\\t\\tlog.Println(\\\"read error:\\\", err)\\n\\t\\t\\t\\tbreak\\n\\t\\t\\t}\\n\\t\\t\\tlog.Printf(\\\"received: %s\\\", message)\\n\\t\\t\\terr = c.WriteMessage(mt, message)\\n\\t\\t\\tif err != nil {\\n\\t\\t\\t\\tlog.Println(\\\"write error:\\\", err)\\n\\t\\t\\t\\tbreak\\n\\t\\t\\t}\\n\\t\\t}\\n\\t}))\\n\\n\\tlog.Fatal(app.Listen(\\\":3000\\\"))\\n}\",\n \"go.mod\": \"module fiber-websocket\\n\\ngo 1.16\\n\\nrequire (\\n\\tgithub.com/gofiber/fiber/v2 v2.12.1\\n\\tgithub.com/gofiber/websocket/v2 v2.10.2\\n)\",\n \"README.md\": \"# Golang Fiber WebSocket项目\\n\\n这个项目是一个使用 Golang 和 Fiber 框架构建的 WebSocket 服务器示例。\\n\\n### 运行应用程序:\\n```shell\\ngo run main.go\\n```\\n\\n应用程序将在本地服务器(默认是在http://localhost:3000)上运行。当你在浏览器中访问`http://localhost:3000`时,将看到显示\"Hello, World!\"的页面。你还可以访问`http://localhost:3000/ws`来测试 WebSocket 连接。\n\n注意:在运行应用程序之前,请确保已经安装了Go语言开发环境。\"\n }\n}"}, {Role: "user", Content: prompt}, }, -1, func(data string) { - hook(data) buffer.Write(data) + hook(buffer, data) }) resp, err := utils.Unmarshal[ProjectResult](buffer.ReadBytes())