mirror of
https://github.com/coaidev/coai.git
synced 2025-05-19 04:50:14 +09:00
86 lines
2.3 KiB
Go
86 lines
2.3 KiB
Go
package article
|
|
|
|
import (
|
|
"chat/auth"
|
|
"chat/globals"
|
|
"chat/manager"
|
|
"chat/utils"
|
|
"fmt"
|
|
"github.com/gin-gonic/gin"
|
|
"strings"
|
|
)
|
|
|
|
// StreamProgressResponse is the progress update handed to the hook callback
// of CreateWorker while a batch of articles is being generated.
type StreamProgressResponse struct {
	// Current is the number of articles received so far.
	Current int `json:"current"`
	// Total is the number of titles in the batch.
	Total int `json:"total"`
	// Quota is the quota reported with the most recently received article;
	// the first and last updates of a batch carry 0.
	Quota float32 `json:"quota"`
}
|
|
|
|
// Response is the outcome of generating a single article.
type Response struct {
	// File is the value returned by CreateArticleFile for the article.
	File string
	// Quota is the quota value returned by the chat handler for the request.
	Quota float32
}
|
|
|
|
func GenerateArticle(c *gin.Context, user *auth.User, model string, hash string, title string, prompt string, enableWeb bool) Response {
|
|
message, quota := manager.NativeChatHandler(c, user, model, []globals.Message{{
|
|
Role: globals.User,
|
|
Content: fmt.Sprintf("%s\n%s", prompt, title),
|
|
}}, enableWeb)
|
|
|
|
return Response{
|
|
File: CreateArticleFile(hash, title, message),
|
|
Quota: quota,
|
|
}
|
|
}
|
|
|
|
// ParseTitle splits titles on newlines and returns the non-empty entries with
// surrounding whitespace removed, in their original order. It returns nil when
// no entry remains (for example, for an empty or whitespace-only input).
func ParseTitle(titles string) []string {
	var parsed []string
	lines := strings.Split(titles, "\n")
	for i := range lines {
		trimmed := strings.TrimSpace(lines[i])
		if trimmed == "" {
			// Blank and whitespace-only lines are not titles.
			continue
		}
		parsed = append(parsed, trimmed)
	}
	return parsed
}
|
|
|
|
func CreateGenerationWorker(c *gin.Context, user *auth.User, model string, prompt string, title string, enableWeb bool, hash string) (int, chan Response) {
|
|
titles := ParseTitle(title)
|
|
result := make(chan Response, len(titles))
|
|
|
|
for _, name := range titles {
|
|
go func(title string) {
|
|
result <- GenerateArticle(c, user, model, hash, title, prompt, enableWeb)
|
|
}(name)
|
|
}
|
|
|
|
return len(titles), result
|
|
}
|
|
|
|
func CreateWorker(c *gin.Context, user *auth.User, model string, prompt string, title string, enableWeb bool, hook func(resp StreamProgressResponse)) string {
|
|
hash := utils.Md5Encrypt(fmt.Sprintf("%s%s%s%v", model, prompt, title, enableWeb))
|
|
total, channel := CreateGenerationWorker(c, user, model, prompt, title, enableWeb, hash)
|
|
current := 0
|
|
|
|
hook(StreamProgressResponse{Current: current, Total: total, Quota: 0})
|
|
|
|
for resp := range channel {
|
|
current += 1
|
|
hook(StreamProgressResponse{Current: current, Total: total, Quota: resp.Quota})
|
|
|
|
if current == total {
|
|
break
|
|
}
|
|
}
|
|
|
|
hook(StreamProgressResponse{Current: current, Total: total, Quota: 0})
|
|
|
|
path := fmt.Sprintf("storage/article/data/%s", hash)
|
|
if _, _, err := utils.GenerateCompressTask(hash, "storage/article", path, path); err != nil {
|
|
globals.Debug(fmt.Sprintf("[article] error during generate compress task: %s", err.Error()))
|
|
return ""
|
|
}
|
|
|
|
return hash
|
|
}
|