mirror of
https://github.com/coaidev/coai.git
synced 2025-05-29 09:50:16 +09:00
feat: support model cache buffer feature (support cache expiration
cache accepted models
cache size
)
This commit is contained in:
parent
b52f7b6617
commit
780a55a109
@ -44,7 +44,7 @@ import {
|
||||
} from "@/components/ui/dialog.tsx";
|
||||
import { DialogTitle } from "@radix-ui/react-dialog";
|
||||
import Require from "@/components/Require.tsx";
|
||||
import { Loader2 } from "lucide-react";
|
||||
import { Loader2, Settings2 } from "lucide-react";
|
||||
import { Textarea } from "@/components/ui/textarea.tsx";
|
||||
import Tips from "@/components/Tips.tsx";
|
||||
import { cn } from "@/components/ui/lib/utils.ts";
|
||||
@ -52,6 +52,7 @@ import { Switch } from "@/components/ui/switch.tsx";
|
||||
import { MultiCombobox } from "@/components/ui/multi-combobox.tsx";
|
||||
import { allGroups } from "@/utils/groups.ts";
|
||||
import { channelModels } from "@/admin/channel.ts";
|
||||
import { supportModels } from "@/conf";
|
||||
|
||||
type CompProps<T> = {
|
||||
data: T;
|
||||
@ -592,6 +593,46 @@ function Common({ data, dispatch, onChange }: CompProps<CommonState>) {
|
||||
min={0}
|
||||
/>
|
||||
</ParagraphItem>
|
||||
<ParagraphItem>
|
||||
<div className={`flex flex-row flex-wrap gap-2 ml-auto`}>
|
||||
<Button
|
||||
variant={`outline`}
|
||||
onClick={() => dispatch({ type: "update:common.cache", value: [] })}
|
||||
>
|
||||
<Settings2
|
||||
className={`inline-flex h-4 w-4 mr-2 translate-y-[1px]`}
|
||||
/>
|
||||
{t("admin.system.cacheNone")}
|
||||
</Button>
|
||||
<Button
|
||||
variant={`outline`}
|
||||
onClick={() =>
|
||||
dispatch({
|
||||
type: "update:common.cache",
|
||||
value: supportModels
|
||||
.filter((item) => item.free)
|
||||
.map((item) => item.id),
|
||||
})
|
||||
}
|
||||
>
|
||||
<Settings2
|
||||
className={`inline-flex h-4 w-4 mr-2 translate-y-[1px]`}
|
||||
/>
|
||||
{t("admin.system.cacheFree")}
|
||||
</Button>
|
||||
<Button
|
||||
variant={`outline`}
|
||||
onClick={() =>
|
||||
dispatch({ type: "update:common.cache", value: channelModels })
|
||||
}
|
||||
>
|
||||
<Settings2
|
||||
className={`inline-flex h-4 w-4 mr-2 translate-y-[1px]`}
|
||||
/>
|
||||
{t("admin.system.cacheAll")}
|
||||
</Button>
|
||||
</div>
|
||||
</ParagraphItem>
|
||||
<ParagraphSpace />
|
||||
<ParagraphItem>
|
||||
<Label className={`flex flex-row items-center`}>
|
||||
|
@ -5,6 +5,8 @@ import (
|
||||
"chat/globals"
|
||||
"chat/utils"
|
||||
"fmt"
|
||||
"github.com/go-redis/redis/v8"
|
||||
"time"
|
||||
)
|
||||
|
||||
func NewChatRequest(group string, props *adapter.ChatProps, hook globals.Hook) error {
|
||||
@ -28,3 +30,50 @@ func NewChatRequest(group string, props *adapter.ChatProps, hook globals.Hook) e
|
||||
globals.Info(fmt.Sprintf("[channel] channels are exhausted for model %s", props.Model))
|
||||
return err
|
||||
}
|
||||
|
||||
func PreflightCache(cache *redis.Client, hash string, buffer *utils.Buffer, hook globals.Hook) (int64, bool, error) {
|
||||
idx := utils.Intn64(globals.CacheAcceptedSize)
|
||||
key := fmt.Sprintf("chat-cache:%d:%s", idx, hash)
|
||||
|
||||
raw, err := cache.Get(cache.Context(), key).Result()
|
||||
if err != nil {
|
||||
return idx, false, nil
|
||||
}
|
||||
|
||||
buf, err := utils.UnmarshalString[utils.Buffer](raw)
|
||||
if err != nil {
|
||||
fmt.Println(err)
|
||||
return idx, false, nil
|
||||
}
|
||||
|
||||
data := buf.Read()
|
||||
if data == "" {
|
||||
return idx, false, nil
|
||||
}
|
||||
|
||||
buffer.SetToolCalls(buf.GetToolCalls())
|
||||
return idx, true, hook(data)
|
||||
}
|
||||
|
||||
func StoreCache(cache *redis.Client, hash string, index int64, buffer *utils.Buffer) {
|
||||
key := fmt.Sprintf("chat-cache:%d:%s", index, hash)
|
||||
raw := utils.Marshal(buffer)
|
||||
expire := time.Duration(globals.CacheAcceptedExpire) * time.Second
|
||||
|
||||
cache.Set(cache.Context(), key, raw, expire)
|
||||
}
|
||||
|
||||
func NewChatRequestWithCache(cache *redis.Client, buffer *utils.Buffer, group string, props *adapter.ChatProps, hook globals.Hook) (bool, error) {
|
||||
hash := utils.Md5Encrypt(utils.Marshal(props))
|
||||
idx, hit, err := PreflightCache(cache, hash, buffer, hook)
|
||||
if hit {
|
||||
return true, err
|
||||
}
|
||||
|
||||
if err = NewChatRequest(group, props, hook); err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
StoreCache(cache, hash, idx, buffer)
|
||||
return false, nil
|
||||
}
|
||||
|
@ -19,15 +19,16 @@ const defaultMessage = "empty response"
|
||||
func CollectQuota(c *gin.Context, user *auth.User, buffer *utils.Buffer, uncountable bool, err error) {
|
||||
db := utils.GetDBFromContext(c)
|
||||
quota := buffer.GetQuota()
|
||||
if buffer.IsEmpty() {
|
||||
return
|
||||
} else if buffer.GetCharge().IsBillingType(globals.TimesBilling) && err != nil {
|
||||
// billing type is times, but error occurred
|
||||
|
||||
if user == nil || quota <= 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// collect quota for tokens billing (though error occurred) or times billing
|
||||
if !uncountable && quota > 0 && user != nil {
|
||||
if buffer.IsEmpty() || err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if !uncountable {
|
||||
user.UseQuota(db, quota)
|
||||
}
|
||||
}
|
||||
@ -92,7 +93,8 @@ func ChatHandler(conn *Connection, user *auth.User, instance *conversation.Conve
|
||||
}
|
||||
|
||||
buffer := utils.NewBuffer(model, segment, channel.ChargeInstance.GetCharge(model))
|
||||
err := channel.NewChatRequest(
|
||||
hit, err := channel.NewChatRequestWithCache(
|
||||
cache, buffer,
|
||||
auth.GetGroup(db, user),
|
||||
&adapter.ChatProps{
|
||||
Model: model,
|
||||
@ -125,7 +127,6 @@ func ChatHandler(conn *Connection, user *auth.User, instance *conversation.Conve
|
||||
globals.Warn(fmt.Sprintf("caught error from chat handler: %s (instance: %s, client: %s)", err, model, conn.GetCtx().ClientIP()))
|
||||
|
||||
auth.RevertSubscriptionUsage(db, cache, user, model)
|
||||
CollectQuota(conn.GetCtx(), user, buffer, plan, err)
|
||||
conn.Send(globals.ChatSegmentResponse{
|
||||
Message: err.Error(),
|
||||
End: true,
|
||||
@ -133,7 +134,9 @@ func ChatHandler(conn *Connection, user *auth.User, instance *conversation.Conve
|
||||
return err.Error()
|
||||
}
|
||||
|
||||
if !hit {
|
||||
CollectQuota(conn.GetCtx(), user, buffer, plan, err)
|
||||
}
|
||||
|
||||
if buffer.IsEmpty() {
|
||||
conn.Send(globals.ChatSegmentResponse{
|
||||
|
@ -89,7 +89,7 @@ func sendTranshipmentResponse(c *gin.Context, form RelayForm, messages []globals
|
||||
cache := utils.GetCacheFromContext(c)
|
||||
|
||||
buffer := utils.NewBuffer(form.Model, messages, channel.ChargeInstance.GetCharge(form.Model))
|
||||
err := channel.NewChatRequest(auth.GetGroup(db, user), getChatProps(form, messages, buffer, plan), func(data string) error {
|
||||
hit, err := channel.NewChatRequestWithCache(cache, buffer, auth.GetGroup(db, user), getChatProps(form, messages, buffer, plan), func(data string) error {
|
||||
buffer.Write(data)
|
||||
return nil
|
||||
})
|
||||
@ -103,7 +103,10 @@ func sendTranshipmentResponse(c *gin.Context, form RelayForm, messages []globals
|
||||
return
|
||||
}
|
||||
|
||||
if !hit {
|
||||
CollectQuota(c, user, buffer, plan, err)
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, RelayResponse{
|
||||
Id: fmt.Sprintf("chatcmpl-%s", id),
|
||||
Object: "chat.completion",
|
||||
@ -158,7 +161,7 @@ func sendStreamTranshipmentResponse(c *gin.Context, form RelayForm, messages []g
|
||||
|
||||
go func() {
|
||||
buffer := utils.NewBuffer(form.Model, messages, channel.ChargeInstance.GetCharge(form.Model))
|
||||
err := channel.NewChatRequest(auth.GetGroup(db, user), getChatProps(form, messages, buffer, plan), func(data string) error {
|
||||
hit, err := channel.NewChatRequestWithCache(cache, buffer, auth.GetGroup(db, user), getChatProps(form, messages, buffer, plan), func(data string) error {
|
||||
partial <- getStreamTranshipmentForm(id, created, form, buffer.Write(data), buffer, false, nil)
|
||||
return nil
|
||||
})
|
||||
@ -173,7 +176,11 @@ func sendStreamTranshipmentResponse(c *gin.Context, form RelayForm, messages []g
|
||||
}
|
||||
|
||||
partial <- getStreamTranshipmentForm(id, created, form, "", buffer, true, nil)
|
||||
|
||||
if !hit {
|
||||
CollectQuota(c, user, buffer, plan, err)
|
||||
}
|
||||
|
||||
close(partial)
|
||||
return
|
||||
}()
|
||||
|
@ -40,7 +40,8 @@ func NativeChatHandler(c *gin.Context, user *auth.User, model string, message []
|
||||
}
|
||||
|
||||
buffer := utils.NewBuffer(model, segment, channel.ChargeInstance.GetCharge(model))
|
||||
err := channel.NewChatRequest(
|
||||
hit, err := channel.NewChatRequestWithCache(
|
||||
cache, buffer,
|
||||
auth.GetGroup(db, user),
|
||||
&adapter.ChatProps{
|
||||
Model: model,
|
||||
@ -56,11 +57,12 @@ func NativeChatHandler(c *gin.Context, user *auth.User, model string, message []
|
||||
admin.AnalysisRequest(model, buffer, err)
|
||||
if err != nil {
|
||||
auth.RevertSubscriptionUsage(db, cache, user, model)
|
||||
CollectQuota(c, user, buffer, plan, err)
|
||||
return err.Error(), 0
|
||||
}
|
||||
|
||||
if !hit {
|
||||
CollectQuota(c, user, buffer, plan, err)
|
||||
}
|
||||
|
||||
SaveCacheData(c, &CacheProps{
|
||||
Message: segment,
|
||||
|
@ -89,7 +89,7 @@ func createRelayImageObject(c *gin.Context, form RelayImageForm, prompt string,
|
||||
}
|
||||
|
||||
buffer := utils.NewBuffer(form.Model, messages, channel.ChargeInstance.GetCharge(form.Model))
|
||||
err := channel.NewChatRequest(auth.GetGroup(db, user), getImageProps(form, messages, buffer), func(data string) error {
|
||||
hit, err := channel.NewChatRequestWithCache(cache, buffer, auth.GetGroup(db, user), getImageProps(form, messages, buffer), func(data string) error {
|
||||
buffer.Write(data)
|
||||
return nil
|
||||
})
|
||||
@ -103,7 +103,9 @@ func createRelayImageObject(c *gin.Context, form RelayImageForm, prompt string,
|
||||
return
|
||||
}
|
||||
|
||||
if !hit {
|
||||
CollectQuota(c, user, buffer, plan, err)
|
||||
}
|
||||
|
||||
image := getUrlFromBuffer(buffer)
|
||||
if image == "" {
|
||||
|
@ -13,6 +13,12 @@ func Intn(n int) int {
|
||||
return r.Intn(n)
|
||||
}
|
||||
|
||||
func Intn64(n int64) int64 {
|
||||
source := rand.NewSource(time.Now().UnixNano())
|
||||
r := rand.New(source)
|
||||
return r.Int63n(n)
|
||||
}
|
||||
|
||||
func IntnSeed(n int, seed int) int {
|
||||
// unix nano is the same if called in the same nanosecond, so we need to add another random seed
|
||||
source := rand.NewSource(time.Now().UnixNano() + int64(seed))
|
||||
|
@ -26,7 +26,7 @@ type Buffer struct {
|
||||
History []globals.Message `json:"history"`
|
||||
Images Images `json:"images"`
|
||||
ToolCalls *globals.ToolCalls `json:"tool_calls"`
|
||||
Charge Charge `json:"charge"`
|
||||
Charge Charge `json:"-"`
|
||||
}
|
||||
|
||||
func NewBuffer(model string, history []globals.Message, charge Charge) *Buffer {
|
||||
|
@ -47,8 +47,13 @@ func Unmarshal[T interface{}](data []byte) (form T, err error) {
|
||||
return form, err
|
||||
}
|
||||
|
||||
func UnmarshalString[T interface{}](data string) (form T, err error) {
|
||||
err = json.Unmarshal([]byte(data), &form)
|
||||
return form, err
|
||||
}
|
||||
|
||||
func UnmarshalForm[T interface{}](data string) *T {
|
||||
form, err := Unmarshal[T]([]byte(data))
|
||||
form, err := UnmarshalString[T](data)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user