mirror of
https://github.com/coaidev/coai.git
synced 2025-05-29 01:40:17 +09:00
update generation quota counter
This commit is contained in:
parent
3d165bf42d
commit
6577e54a2e
@ -95,23 +95,32 @@ func GenerateAPI(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
hash, err := CreateGenerationWithCache(form.Model, form.Prompt, useReverse, func(data string) {
|
||||
var instance *api.Buffer
|
||||
hash, err := CreateGenerationWithCache(form.Model, form.Prompt, useReverse, func(buffer *api.Buffer, data string) {
|
||||
instance = buffer
|
||||
api.SendSegmentMessage(conn, types.GenerationSegmentResponse{
|
||||
End: false,
|
||||
Message: data,
|
||||
Quota: buffer.GetQuota(),
|
||||
})
|
||||
})
|
||||
|
||||
if instance != nil && !useReverse {
|
||||
user.UseQuota(db, instance.GetQuota())
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
api.SendSegmentMessage(conn, types.GenerationSegmentResponse{
|
||||
End: true,
|
||||
Error: err.Error(),
|
||||
Quota: instance.GetQuota(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
api.SendSegmentMessage(conn, types.GenerationSegmentResponse{
|
||||
End: true,
|
||||
Hash: hash,
|
||||
End: true,
|
||||
Hash: hash,
|
||||
Quota: instance.GetQuota(),
|
||||
})
|
||||
}
|
||||
|
@ -1,16 +1,15 @@
|
||||
package generation
|
||||
|
||||
import (
|
||||
"chat/api"
|
||||
"chat/utils"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
func CreateGenerationWithCache(model string, prompt string, enableReverse bool, hook func(data string)) (string, error) {
|
||||
func CreateGenerationWithCache(model string, prompt string, enableReverse bool, hook func(buffer *api.Buffer, data string)) (string, error) {
|
||||
hash, path := GetFolderByHash(model, prompt)
|
||||
if !utils.Exists(path) {
|
||||
if err := CreateGeneration(model, prompt, path, enableReverse, func(data string) {
|
||||
hook(data)
|
||||
}); err != nil {
|
||||
if err := CreateGeneration(model, prompt, path, enableReverse, hook); err != nil {
|
||||
fmt.Println(fmt.Sprintf("[Project] error during generation %s (model %s): %s", prompt, model, err.Error()))
|
||||
return "", fmt.Errorf("error during generate project: %s", err.Error())
|
||||
}
|
||||
|
@ -11,7 +11,7 @@ type ProjectResult struct {
|
||||
Result map[string]interface{} `json:"result"`
|
||||
}
|
||||
|
||||
func CreateGeneration(model string, prompt string, path string, enableReverse bool, hook func(data string)) error {
|
||||
func CreateGeneration(model string, prompt string, path string, enableReverse bool, hook func(buffer *api.Buffer, data string)) error {
|
||||
message := GenerateMessage(prompt)
|
||||
buffer := api.NewBuffer(model, message)
|
||||
api.StreamRequest(model, enableReverse, []types.ChatGPTMessage{
|
||||
@ -22,8 +22,8 @@ func CreateGeneration(model string, prompt string, path string, enableReverse bo
|
||||
{Role: "assistant", Content: "{\n \"result\": {\n \"main.go\": \"package main\\n\\nimport (\\n\\t\"log\\\"\\n\\n\\t\"github.com/gofiber/fiber/v2\\\"\\n\\t\"github.com/gofiber/websocket/v2\\\"\\n)\\n\\nfunc main() {\\n\\tapp := fiber.New()\\n\\n\\tapp.Get(\\\"/\\\", func(c *fiber.Ctx) error {\\n\\t\\treturn c.SendString(\\\"Hello, World!\\\")\\n\\t})\\n\\n\\tapp.Get(\\\"/ws\\\", websocket.New(func(c *websocket.Conn) {\\n\\t\\tfor {\\n\\t\\t\\tmt, message, err := c.ReadMessage()\\n\\t\\t\\tif err != nil {\\n\\t\\t\\t\\tlog.Println(\\\"read error:\\\", err)\\n\\t\\t\\t\\tbreak\\n\\t\\t\\t}\\n\\t\\t\\tlog.Printf(\\\"received: %s\\\", message)\\n\\t\\t\\terr = c.WriteMessage(mt, message)\\n\\t\\t\\tif err != nil {\\n\\t\\t\\t\\tlog.Println(\\\"write error:\\\", err)\\n\\t\\t\\t\\tbreak\\n\\t\\t\\t}\\n\\t\\t}\\n\\t}))\\n\\n\\tlog.Fatal(app.Listen(\\\":3000\\\"))\\n}\",\n \"go.mod\": \"module fiber-websocket\\n\\ngo 1.16\\n\\nrequire (\\n\\tgithub.com/gofiber/fiber/v2 v2.12.1\\n\\tgithub.com/gofiber/websocket/v2 v2.10.2\\n)\",\n \"README.md\": \"# Golang Fiber WebSocket项目\\n\\n这个项目是一个使用 Golang 和 Fiber 框架构建的 WebSocket 服务器示例。\\n\\n### 运行应用程序:\\n```shell\\ngo run main.go\\n```\\n\\n应用程序将在本地服务器(默认是在http://localhost:3000)上运行。当你在浏览器中访问`http://localhost:3000`时,将看到显示\"Hello, World!\"的页面。你还可以访问`http://localhost:3000/ws`来测试 WebSocket 连接。\n\n注意:在运行应用程序之前,请确保已经安装了Go语言开发环境。\"\n }\n}"},
|
||||
{Role: "user", Content: prompt},
|
||||
}, -1, func(data string) {
|
||||
hook(data)
|
||||
buffer.Write(data)
|
||||
hook(buffer, data)
|
||||
})
|
||||
|
||||
resp, err := utils.Unmarshal[ProjectResult](buffer.ReadBytes())
|
||||
|
Loading…
Reference in New Issue
Block a user