feat: support model cache buffer feature (supports cache expiration, accepted cache models, and cache size)

This commit is contained in:
Zhang Minghan 2024-02-13 15:08:03 +08:00
parent b52f7b6617
commit 780a55a109
9 changed files with 136 additions and 21 deletions

View File

@ -44,7 +44,7 @@ import {
} from "@/components/ui/dialog.tsx";
import { DialogTitle } from "@radix-ui/react-dialog";
import Require from "@/components/Require.tsx";
import { Loader2 } from "lucide-react";
import { Loader2, Settings2 } from "lucide-react";
import { Textarea } from "@/components/ui/textarea.tsx";
import Tips from "@/components/Tips.tsx";
import { cn } from "@/components/ui/lib/utils.ts";
@ -52,6 +52,7 @@ import { Switch } from "@/components/ui/switch.tsx";
import { MultiCombobox } from "@/components/ui/multi-combobox.tsx";
import { allGroups } from "@/utils/groups.ts";
import { channelModels } from "@/admin/channel.ts";
import { supportModels } from "@/conf";
type CompProps<T> = {
data: T;
@ -592,6 +593,46 @@ function Common({ data, dispatch, onChange }: CompProps<CommonState>) {
min={0}
/>
</ParagraphItem>
<ParagraphItem>
<div className={`flex flex-row flex-wrap gap-2 ml-auto`}>
<Button
variant={`outline`}
onClick={() => dispatch({ type: "update:common.cache", value: [] })}
>
<Settings2
className={`inline-flex h-4 w-4 mr-2 translate-y-[1px]`}
/>
{t("admin.system.cacheNone")}
</Button>
<Button
variant={`outline`}
onClick={() =>
dispatch({
type: "update:common.cache",
value: supportModels
.filter((item) => item.free)
.map((item) => item.id),
})
}
>
<Settings2
className={`inline-flex h-4 w-4 mr-2 translate-y-[1px]`}
/>
{t("admin.system.cacheFree")}
</Button>
<Button
variant={`outline`}
onClick={() =>
dispatch({ type: "update:common.cache", value: channelModels })
}
>
<Settings2
className={`inline-flex h-4 w-4 mr-2 translate-y-[1px]`}
/>
{t("admin.system.cacheAll")}
</Button>
</div>
</ParagraphItem>
<ParagraphSpace />
<ParagraphItem>
<Label className={`flex flex-row items-center`}>

View File

@ -5,6 +5,8 @@ import (
"chat/globals"
"chat/utils"
"fmt"
"github.com/go-redis/redis/v8"
"time"
)
func NewChatRequest(group string, props *adapter.ChatProps, hook globals.Hook) error {
@ -28,3 +30,50 @@ func NewChatRequest(group string, props *adapter.ChatProps, hook globals.Hook) e
globals.Info(fmt.Sprintf("[channel] channels are exhausted for model %s", props.Model))
return err
}
// PreflightCache looks for a stored response to the request identified by hash
// and, when one is found, replays it through hook.
//
// A slot index is drawn with utils.Intn64(globals.CacheAcceptedSize) and combined
// with hash into the key "chat-cache:<idx>:<hash>". The index is returned in every
// case so the caller can store a fresh response under the same key after a miss.
//
// Returns (idx, hit, err):
//   - hit is false (and err is nil) when the key cannot be read, the stored
//     payload cannot be decoded, or the decoded buffer reads back as "".
//   - hit is true when cached content was replayed; err is then whatever hook
//     returned. Before replaying, the cached tool calls are copied onto buffer.
func PreflightCache(cache *redis.Client, hash string, buffer *utils.Buffer, hook globals.Hook) (int64, bool, error) {
	idx := utils.Intn64(globals.CacheAcceptedSize)
	key := fmt.Sprintf("chat-cache:%d:%s", idx, hash)

	raw, err := cache.Get(cache.Context(), key).Result()
	if err != nil {
		// any read failure, including a missing key, is treated as a cache miss
		return idx, false, nil
	}

	buf, err := utils.UnmarshalString[utils.Buffer](raw)
	if err != nil {
		// a corrupt entry must not fail the request: report it through the
		// project logger (with the key for diagnosis) and fall back to a miss
		globals.Warn(fmt.Sprintf("[cache] failed to decode cached buffer (key: %s): %s", key, err))
		return idx, false, nil
	}

	data := buf.Read()
	if data == "" {
		// an entry without content is not worth replaying
		return idx, false, nil
	}

	buffer.SetToolCalls(buf.GetToolCalls())
	return idx, true, hook(data)
}
// StoreCache serializes buffer with utils.Marshal and writes it to the key
// "chat-cache:<index>:<hash>" with a time-to-live of globals.CacheAcceptedExpire
// seconds.
//
// Storing is best-effort: a failed write is logged and otherwise ignored, so it
// never affects the request that produced the buffer.
func StoreCache(cache *redis.Client, hash string, index int64, buffer *utils.Buffer) {
	key := fmt.Sprintf("chat-cache:%d:%s", index, hash)
	raw := utils.Marshal(buffer)
	// NOTE(review): go-redis treats a zero expiration as "no expiry"; confirm
	// that globals.CacheAcceptedExpire is always positive.
	expire := time.Duration(globals.CacheAcceptedExpire) * time.Second

	if err := cache.Set(cache.Context(), key, raw, expire).Err(); err != nil {
		globals.Warn(fmt.Sprintf("[cache] failed to store cached buffer (key: %s): %s", key, err))
	}
}
// NewChatRequestWithCache wraps NewChatRequest with a response cache.
//
// The cache key is derived from the MD5 digest of the marshalled props. When
// PreflightCache reports a hit, the cached content has already been replayed
// through hook and (true, <hook error>) is returned without contacting any
// channel. Otherwise the request is sent via NewChatRequest; on success the
// buffer is stored under the slot chosen during preflight and (false, nil) is
// returned, while a request error is returned as (false, err) and nothing is
// stored.
func NewChatRequestWithCache(cache *redis.Client, buffer *utils.Buffer, group string, props *adapter.ChatProps, hook globals.Hook) (bool, error) {
	digest := utils.Md5Encrypt(utils.Marshal(props))

	slot, cached, replayErr := PreflightCache(cache, digest, buffer, hook)
	if cached {
		return true, replayErr
	}

	requestErr := NewChatRequest(group, props, hook)
	if requestErr != nil {
		return false, requestErr
	}

	StoreCache(cache, digest, slot, buffer)
	return false, nil
}

View File

@ -19,15 +19,16 @@ const defaultMessage = "empty response"
func CollectQuota(c *gin.Context, user *auth.User, buffer *utils.Buffer, uncountable bool, err error) {
db := utils.GetDBFromContext(c)
quota := buffer.GetQuota()
if buffer.IsEmpty() {
return
} else if buffer.GetCharge().IsBillingType(globals.TimesBilling) && err != nil {
// billing type is times, but error occurred
if user == nil || quota <= 0 {
return
}
// collect quota for tokens billing (though error occurred) or times billing
if !uncountable && quota > 0 && user != nil {
if buffer.IsEmpty() || err != nil {
return
}
if !uncountable {
user.UseQuota(db, quota)
}
}
@ -92,7 +93,8 @@ func ChatHandler(conn *Connection, user *auth.User, instance *conversation.Conve
}
buffer := utils.NewBuffer(model, segment, channel.ChargeInstance.GetCharge(model))
err := channel.NewChatRequest(
hit, err := channel.NewChatRequestWithCache(
cache, buffer,
auth.GetGroup(db, user),
&adapter.ChatProps{
Model: model,
@ -125,7 +127,6 @@ func ChatHandler(conn *Connection, user *auth.User, instance *conversation.Conve
globals.Warn(fmt.Sprintf("caught error from chat handler: %s (instance: %s, client: %s)", err, model, conn.GetCtx().ClientIP()))
auth.RevertSubscriptionUsage(db, cache, user, model)
CollectQuota(conn.GetCtx(), user, buffer, plan, err)
conn.Send(globals.ChatSegmentResponse{
Message: err.Error(),
End: true,
@ -133,7 +134,9 @@ func ChatHandler(conn *Connection, user *auth.User, instance *conversation.Conve
return err.Error()
}
if !hit {
CollectQuota(conn.GetCtx(), user, buffer, plan, err)
}
if buffer.IsEmpty() {
conn.Send(globals.ChatSegmentResponse{

View File

@ -89,7 +89,7 @@ func sendTranshipmentResponse(c *gin.Context, form RelayForm, messages []globals
cache := utils.GetCacheFromContext(c)
buffer := utils.NewBuffer(form.Model, messages, channel.ChargeInstance.GetCharge(form.Model))
err := channel.NewChatRequest(auth.GetGroup(db, user), getChatProps(form, messages, buffer, plan), func(data string) error {
hit, err := channel.NewChatRequestWithCache(cache, buffer, auth.GetGroup(db, user), getChatProps(form, messages, buffer, plan), func(data string) error {
buffer.Write(data)
return nil
})
@ -103,7 +103,10 @@ func sendTranshipmentResponse(c *gin.Context, form RelayForm, messages []globals
return
}
if !hit {
CollectQuota(c, user, buffer, plan, err)
}
c.JSON(http.StatusOK, RelayResponse{
Id: fmt.Sprintf("chatcmpl-%s", id),
Object: "chat.completion",
@ -158,7 +161,7 @@ func sendStreamTranshipmentResponse(c *gin.Context, form RelayForm, messages []g
go func() {
buffer := utils.NewBuffer(form.Model, messages, channel.ChargeInstance.GetCharge(form.Model))
err := channel.NewChatRequest(auth.GetGroup(db, user), getChatProps(form, messages, buffer, plan), func(data string) error {
hit, err := channel.NewChatRequestWithCache(cache, buffer, auth.GetGroup(db, user), getChatProps(form, messages, buffer, plan), func(data string) error {
partial <- getStreamTranshipmentForm(id, created, form, buffer.Write(data), buffer, false, nil)
return nil
})
@ -173,7 +176,11 @@ func sendStreamTranshipmentResponse(c *gin.Context, form RelayForm, messages []g
}
partial <- getStreamTranshipmentForm(id, created, form, "", buffer, true, nil)
if !hit {
CollectQuota(c, user, buffer, plan, err)
}
close(partial)
return
}()

View File

@ -40,7 +40,8 @@ func NativeChatHandler(c *gin.Context, user *auth.User, model string, message []
}
buffer := utils.NewBuffer(model, segment, channel.ChargeInstance.GetCharge(model))
err := channel.NewChatRequest(
hit, err := channel.NewChatRequestWithCache(
cache, buffer,
auth.GetGroup(db, user),
&adapter.ChatProps{
Model: model,
@ -56,11 +57,12 @@ func NativeChatHandler(c *gin.Context, user *auth.User, model string, message []
admin.AnalysisRequest(model, buffer, err)
if err != nil {
auth.RevertSubscriptionUsage(db, cache, user, model)
CollectQuota(c, user, buffer, plan, err)
return err.Error(), 0
}
if !hit {
CollectQuota(c, user, buffer, plan, err)
}
SaveCacheData(c, &CacheProps{
Message: segment,

View File

@ -89,7 +89,7 @@ func createRelayImageObject(c *gin.Context, form RelayImageForm, prompt string,
}
buffer := utils.NewBuffer(form.Model, messages, channel.ChargeInstance.GetCharge(form.Model))
err := channel.NewChatRequest(auth.GetGroup(db, user), getImageProps(form, messages, buffer), func(data string) error {
hit, err := channel.NewChatRequestWithCache(cache, buffer, auth.GetGroup(db, user), getImageProps(form, messages, buffer), func(data string) error {
buffer.Write(data)
return nil
})
@ -103,7 +103,9 @@ func createRelayImageObject(c *gin.Context, form RelayImageForm, prompt string,
return
}
if !hit {
CollectQuota(c, user, buffer, plan, err)
}
image := getUrlFromBuffer(buffer)
if image == "" {

View File

@ -13,6 +13,12 @@ func Intn(n int) int {
return r.Intn(n)
}
// Intn64 returns a pseudo-random int64 in the half-open range [0, n), using a
// generator freshly seeded from the current time in nanoseconds.
//
// For n <= 0 it returns 0 instead of letting rand.Int63n panic, so callers that
// pass a runtime-configured bound cannot crash on a zero or negative value.
func Intn64(n int64) int64 {
	if n <= 0 {
		return 0
	}

	source := rand.NewSource(time.Now().UnixNano())
	return rand.New(source).Int63n(n)
}
func IntnSeed(n int, seed int) int {
// unix nano is the same if called in the same nanosecond, so we need to add another random seed
source := rand.NewSource(time.Now().UnixNano() + int64(seed))

View File

@ -26,7 +26,7 @@ type Buffer struct {
History []globals.Message `json:"history"`
Images Images `json:"images"`
ToolCalls *globals.ToolCalls `json:"tool_calls"`
Charge Charge `json:"charge"`
Charge Charge `json:"-"`
}
func NewBuffer(model string, history []globals.Message, charge Charge) *Buffer {

View File

@ -47,8 +47,13 @@ func Unmarshal[T interface{}](data []byte) (form T, err error) {
return form, err
}
// UnmarshalString decodes the JSON text in data into a value of type T.
// It returns the decoded value together with any error reported by
// json.Unmarshal; on failure form is returned in whatever state the decoder
// left it.
func UnmarshalString[T interface{}](data string) (form T, err error) {
	payload := []byte(data)
	if decodeErr := json.Unmarshal(payload, &form); decodeErr != nil {
		return form, decodeErr
	}
	return form, nil
}
func UnmarshalForm[T interface{}](data string) *T {
form, err := Unmarshal[T]([]byte(data))
form, err := UnmarshalString[T](data)
if err != nil {
return nil
}