mirror of
https://github.com/coaidev/coai.git
synced 2025-05-31 02:40:32 +09:00
feat: support model cache buffer feature (support cache expiration
cache accepted models
cache size
)
This commit is contained in:
parent
b52f7b6617
commit
780a55a109
@ -44,7 +44,7 @@ import {
|
|||||||
} from "@/components/ui/dialog.tsx";
|
} from "@/components/ui/dialog.tsx";
|
||||||
import { DialogTitle } from "@radix-ui/react-dialog";
|
import { DialogTitle } from "@radix-ui/react-dialog";
|
||||||
import Require from "@/components/Require.tsx";
|
import Require from "@/components/Require.tsx";
|
||||||
import { Loader2 } from "lucide-react";
|
import { Loader2, Settings2 } from "lucide-react";
|
||||||
import { Textarea } from "@/components/ui/textarea.tsx";
|
import { Textarea } from "@/components/ui/textarea.tsx";
|
||||||
import Tips from "@/components/Tips.tsx";
|
import Tips from "@/components/Tips.tsx";
|
||||||
import { cn } from "@/components/ui/lib/utils.ts";
|
import { cn } from "@/components/ui/lib/utils.ts";
|
||||||
@ -52,6 +52,7 @@ import { Switch } from "@/components/ui/switch.tsx";
|
|||||||
import { MultiCombobox } from "@/components/ui/multi-combobox.tsx";
|
import { MultiCombobox } from "@/components/ui/multi-combobox.tsx";
|
||||||
import { allGroups } from "@/utils/groups.ts";
|
import { allGroups } from "@/utils/groups.ts";
|
||||||
import { channelModels } from "@/admin/channel.ts";
|
import { channelModels } from "@/admin/channel.ts";
|
||||||
|
import { supportModels } from "@/conf";
|
||||||
|
|
||||||
type CompProps<T> = {
|
type CompProps<T> = {
|
||||||
data: T;
|
data: T;
|
||||||
@ -592,6 +593,46 @@ function Common({ data, dispatch, onChange }: CompProps<CommonState>) {
|
|||||||
min={0}
|
min={0}
|
||||||
/>
|
/>
|
||||||
</ParagraphItem>
|
</ParagraphItem>
|
||||||
|
<ParagraphItem>
|
||||||
|
<div className={`flex flex-row flex-wrap gap-2 ml-auto`}>
|
||||||
|
<Button
|
||||||
|
variant={`outline`}
|
||||||
|
onClick={() => dispatch({ type: "update:common.cache", value: [] })}
|
||||||
|
>
|
||||||
|
<Settings2
|
||||||
|
className={`inline-flex h-4 w-4 mr-2 translate-y-[1px]`}
|
||||||
|
/>
|
||||||
|
{t("admin.system.cacheNone")}
|
||||||
|
</Button>
|
||||||
|
<Button
|
||||||
|
variant={`outline`}
|
||||||
|
onClick={() =>
|
||||||
|
dispatch({
|
||||||
|
type: "update:common.cache",
|
||||||
|
value: supportModels
|
||||||
|
.filter((item) => item.free)
|
||||||
|
.map((item) => item.id),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
>
|
||||||
|
<Settings2
|
||||||
|
className={`inline-flex h-4 w-4 mr-2 translate-y-[1px]`}
|
||||||
|
/>
|
||||||
|
{t("admin.system.cacheFree")}
|
||||||
|
</Button>
|
||||||
|
<Button
|
||||||
|
variant={`outline`}
|
||||||
|
onClick={() =>
|
||||||
|
dispatch({ type: "update:common.cache", value: channelModels })
|
||||||
|
}
|
||||||
|
>
|
||||||
|
<Settings2
|
||||||
|
className={`inline-flex h-4 w-4 mr-2 translate-y-[1px]`}
|
||||||
|
/>
|
||||||
|
{t("admin.system.cacheAll")}
|
||||||
|
</Button>
|
||||||
|
</div>
|
||||||
|
</ParagraphItem>
|
||||||
<ParagraphSpace />
|
<ParagraphSpace />
|
||||||
<ParagraphItem>
|
<ParagraphItem>
|
||||||
<Label className={`flex flex-row items-center`}>
|
<Label className={`flex flex-row items-center`}>
|
||||||
|
@ -5,6 +5,8 @@ import (
|
|||||||
"chat/globals"
|
"chat/globals"
|
||||||
"chat/utils"
|
"chat/utils"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"github.com/go-redis/redis/v8"
|
||||||
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
func NewChatRequest(group string, props *adapter.ChatProps, hook globals.Hook) error {
|
func NewChatRequest(group string, props *adapter.ChatProps, hook globals.Hook) error {
|
||||||
@ -28,3 +30,50 @@ func NewChatRequest(group string, props *adapter.ChatProps, hook globals.Hook) e
|
|||||||
globals.Info(fmt.Sprintf("[channel] channels are exhausted for model %s", props.Model))
|
globals.Info(fmt.Sprintf("[channel] channels are exhausted for model %s", props.Model))
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func PreflightCache(cache *redis.Client, hash string, buffer *utils.Buffer, hook globals.Hook) (int64, bool, error) {
|
||||||
|
idx := utils.Intn64(globals.CacheAcceptedSize)
|
||||||
|
key := fmt.Sprintf("chat-cache:%d:%s", idx, hash)
|
||||||
|
|
||||||
|
raw, err := cache.Get(cache.Context(), key).Result()
|
||||||
|
if err != nil {
|
||||||
|
return idx, false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
buf, err := utils.UnmarshalString[utils.Buffer](raw)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Println(err)
|
||||||
|
return idx, false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
data := buf.Read()
|
||||||
|
if data == "" {
|
||||||
|
return idx, false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
buffer.SetToolCalls(buf.GetToolCalls())
|
||||||
|
return idx, true, hook(data)
|
||||||
|
}
|
||||||
|
|
||||||
|
func StoreCache(cache *redis.Client, hash string, index int64, buffer *utils.Buffer) {
|
||||||
|
key := fmt.Sprintf("chat-cache:%d:%s", index, hash)
|
||||||
|
raw := utils.Marshal(buffer)
|
||||||
|
expire := time.Duration(globals.CacheAcceptedExpire) * time.Second
|
||||||
|
|
||||||
|
cache.Set(cache.Context(), key, raw, expire)
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewChatRequestWithCache(cache *redis.Client, buffer *utils.Buffer, group string, props *adapter.ChatProps, hook globals.Hook) (bool, error) {
|
||||||
|
hash := utils.Md5Encrypt(utils.Marshal(props))
|
||||||
|
idx, hit, err := PreflightCache(cache, hash, buffer, hook)
|
||||||
|
if hit {
|
||||||
|
return true, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = NewChatRequest(group, props, hook); err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
StoreCache(cache, hash, idx, buffer)
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
@ -19,15 +19,16 @@ const defaultMessage = "empty response"
|
|||||||
func CollectQuota(c *gin.Context, user *auth.User, buffer *utils.Buffer, uncountable bool, err error) {
|
func CollectQuota(c *gin.Context, user *auth.User, buffer *utils.Buffer, uncountable bool, err error) {
|
||||||
db := utils.GetDBFromContext(c)
|
db := utils.GetDBFromContext(c)
|
||||||
quota := buffer.GetQuota()
|
quota := buffer.GetQuota()
|
||||||
if buffer.IsEmpty() {
|
|
||||||
return
|
if user == nil || quota <= 0 {
|
||||||
} else if buffer.GetCharge().IsBillingType(globals.TimesBilling) && err != nil {
|
|
||||||
// billing type is times, but error occurred
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// collect quota for tokens billing (though error occurred) or times billing
|
if buffer.IsEmpty() || err != nil {
|
||||||
if !uncountable && quota > 0 && user != nil {
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if !uncountable {
|
||||||
user.UseQuota(db, quota)
|
user.UseQuota(db, quota)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -92,7 +93,8 @@ func ChatHandler(conn *Connection, user *auth.User, instance *conversation.Conve
|
|||||||
}
|
}
|
||||||
|
|
||||||
buffer := utils.NewBuffer(model, segment, channel.ChargeInstance.GetCharge(model))
|
buffer := utils.NewBuffer(model, segment, channel.ChargeInstance.GetCharge(model))
|
||||||
err := channel.NewChatRequest(
|
hit, err := channel.NewChatRequestWithCache(
|
||||||
|
cache, buffer,
|
||||||
auth.GetGroup(db, user),
|
auth.GetGroup(db, user),
|
||||||
&adapter.ChatProps{
|
&adapter.ChatProps{
|
||||||
Model: model,
|
Model: model,
|
||||||
@ -125,7 +127,6 @@ func ChatHandler(conn *Connection, user *auth.User, instance *conversation.Conve
|
|||||||
globals.Warn(fmt.Sprintf("caught error from chat handler: %s (instance: %s, client: %s)", err, model, conn.GetCtx().ClientIP()))
|
globals.Warn(fmt.Sprintf("caught error from chat handler: %s (instance: %s, client: %s)", err, model, conn.GetCtx().ClientIP()))
|
||||||
|
|
||||||
auth.RevertSubscriptionUsage(db, cache, user, model)
|
auth.RevertSubscriptionUsage(db, cache, user, model)
|
||||||
CollectQuota(conn.GetCtx(), user, buffer, plan, err)
|
|
||||||
conn.Send(globals.ChatSegmentResponse{
|
conn.Send(globals.ChatSegmentResponse{
|
||||||
Message: err.Error(),
|
Message: err.Error(),
|
||||||
End: true,
|
End: true,
|
||||||
@ -133,7 +134,9 @@ func ChatHandler(conn *Connection, user *auth.User, instance *conversation.Conve
|
|||||||
return err.Error()
|
return err.Error()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if !hit {
|
||||||
CollectQuota(conn.GetCtx(), user, buffer, plan, err)
|
CollectQuota(conn.GetCtx(), user, buffer, plan, err)
|
||||||
|
}
|
||||||
|
|
||||||
if buffer.IsEmpty() {
|
if buffer.IsEmpty() {
|
||||||
conn.Send(globals.ChatSegmentResponse{
|
conn.Send(globals.ChatSegmentResponse{
|
||||||
|
@ -89,7 +89,7 @@ func sendTranshipmentResponse(c *gin.Context, form RelayForm, messages []globals
|
|||||||
cache := utils.GetCacheFromContext(c)
|
cache := utils.GetCacheFromContext(c)
|
||||||
|
|
||||||
buffer := utils.NewBuffer(form.Model, messages, channel.ChargeInstance.GetCharge(form.Model))
|
buffer := utils.NewBuffer(form.Model, messages, channel.ChargeInstance.GetCharge(form.Model))
|
||||||
err := channel.NewChatRequest(auth.GetGroup(db, user), getChatProps(form, messages, buffer, plan), func(data string) error {
|
hit, err := channel.NewChatRequestWithCache(cache, buffer, auth.GetGroup(db, user), getChatProps(form, messages, buffer, plan), func(data string) error {
|
||||||
buffer.Write(data)
|
buffer.Write(data)
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
@ -103,7 +103,10 @@ func sendTranshipmentResponse(c *gin.Context, form RelayForm, messages []globals
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if !hit {
|
||||||
CollectQuota(c, user, buffer, plan, err)
|
CollectQuota(c, user, buffer, plan, err)
|
||||||
|
}
|
||||||
|
|
||||||
c.JSON(http.StatusOK, RelayResponse{
|
c.JSON(http.StatusOK, RelayResponse{
|
||||||
Id: fmt.Sprintf("chatcmpl-%s", id),
|
Id: fmt.Sprintf("chatcmpl-%s", id),
|
||||||
Object: "chat.completion",
|
Object: "chat.completion",
|
||||||
@ -158,7 +161,7 @@ func sendStreamTranshipmentResponse(c *gin.Context, form RelayForm, messages []g
|
|||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
buffer := utils.NewBuffer(form.Model, messages, channel.ChargeInstance.GetCharge(form.Model))
|
buffer := utils.NewBuffer(form.Model, messages, channel.ChargeInstance.GetCharge(form.Model))
|
||||||
err := channel.NewChatRequest(auth.GetGroup(db, user), getChatProps(form, messages, buffer, plan), func(data string) error {
|
hit, err := channel.NewChatRequestWithCache(cache, buffer, auth.GetGroup(db, user), getChatProps(form, messages, buffer, plan), func(data string) error {
|
||||||
partial <- getStreamTranshipmentForm(id, created, form, buffer.Write(data), buffer, false, nil)
|
partial <- getStreamTranshipmentForm(id, created, form, buffer.Write(data), buffer, false, nil)
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
@ -173,7 +176,11 @@ func sendStreamTranshipmentResponse(c *gin.Context, form RelayForm, messages []g
|
|||||||
}
|
}
|
||||||
|
|
||||||
partial <- getStreamTranshipmentForm(id, created, form, "", buffer, true, nil)
|
partial <- getStreamTranshipmentForm(id, created, form, "", buffer, true, nil)
|
||||||
|
|
||||||
|
if !hit {
|
||||||
CollectQuota(c, user, buffer, plan, err)
|
CollectQuota(c, user, buffer, plan, err)
|
||||||
|
}
|
||||||
|
|
||||||
close(partial)
|
close(partial)
|
||||||
return
|
return
|
||||||
}()
|
}()
|
||||||
|
@ -40,7 +40,8 @@ func NativeChatHandler(c *gin.Context, user *auth.User, model string, message []
|
|||||||
}
|
}
|
||||||
|
|
||||||
buffer := utils.NewBuffer(model, segment, channel.ChargeInstance.GetCharge(model))
|
buffer := utils.NewBuffer(model, segment, channel.ChargeInstance.GetCharge(model))
|
||||||
err := channel.NewChatRequest(
|
hit, err := channel.NewChatRequestWithCache(
|
||||||
|
cache, buffer,
|
||||||
auth.GetGroup(db, user),
|
auth.GetGroup(db, user),
|
||||||
&adapter.ChatProps{
|
&adapter.ChatProps{
|
||||||
Model: model,
|
Model: model,
|
||||||
@ -56,11 +57,12 @@ func NativeChatHandler(c *gin.Context, user *auth.User, model string, message []
|
|||||||
admin.AnalysisRequest(model, buffer, err)
|
admin.AnalysisRequest(model, buffer, err)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
auth.RevertSubscriptionUsage(db, cache, user, model)
|
auth.RevertSubscriptionUsage(db, cache, user, model)
|
||||||
CollectQuota(c, user, buffer, plan, err)
|
|
||||||
return err.Error(), 0
|
return err.Error(), 0
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if !hit {
|
||||||
CollectQuota(c, user, buffer, plan, err)
|
CollectQuota(c, user, buffer, plan, err)
|
||||||
|
}
|
||||||
|
|
||||||
SaveCacheData(c, &CacheProps{
|
SaveCacheData(c, &CacheProps{
|
||||||
Message: segment,
|
Message: segment,
|
||||||
|
@ -89,7 +89,7 @@ func createRelayImageObject(c *gin.Context, form RelayImageForm, prompt string,
|
|||||||
}
|
}
|
||||||
|
|
||||||
buffer := utils.NewBuffer(form.Model, messages, channel.ChargeInstance.GetCharge(form.Model))
|
buffer := utils.NewBuffer(form.Model, messages, channel.ChargeInstance.GetCharge(form.Model))
|
||||||
err := channel.NewChatRequest(auth.GetGroup(db, user), getImageProps(form, messages, buffer), func(data string) error {
|
hit, err := channel.NewChatRequestWithCache(cache, buffer, auth.GetGroup(db, user), getImageProps(form, messages, buffer), func(data string) error {
|
||||||
buffer.Write(data)
|
buffer.Write(data)
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
@ -103,7 +103,9 @@ func createRelayImageObject(c *gin.Context, form RelayImageForm, prompt string,
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if !hit {
|
||||||
CollectQuota(c, user, buffer, plan, err)
|
CollectQuota(c, user, buffer, plan, err)
|
||||||
|
}
|
||||||
|
|
||||||
image := getUrlFromBuffer(buffer)
|
image := getUrlFromBuffer(buffer)
|
||||||
if image == "" {
|
if image == "" {
|
||||||
|
@ -13,6 +13,12 @@ func Intn(n int) int {
|
|||||||
return r.Intn(n)
|
return r.Intn(n)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Intn64 returns a pseudo-random int64 in the half-open interval [0, n).
//
// A non-positive n yields 0: rand.Int63n panics for n <= 0, and callers pass
// configuration-derived sizes that must not be able to crash the process.
//
// Each call builds its own generator seeded from the current time in
// nanoseconds.
func Intn64(n int64) int64 {
	if n <= 0 {
		return 0
	}

	source := rand.NewSource(time.Now().UnixNano())
	return rand.New(source).Int63n(n)
}
|
||||||
|
|
||||||
func IntnSeed(n int, seed int) int {
|
func IntnSeed(n int, seed int) int {
|
||||||
// unix nano is the same if called in the same nanosecond, so we need to add another random seed
|
// unix nano is the same if called in the same nanosecond, so we need to add another random seed
|
||||||
source := rand.NewSource(time.Now().UnixNano() + int64(seed))
|
source := rand.NewSource(time.Now().UnixNano() + int64(seed))
|
||||||
|
@ -26,7 +26,7 @@ type Buffer struct {
|
|||||||
History []globals.Message `json:"history"`
|
History []globals.Message `json:"history"`
|
||||||
Images Images `json:"images"`
|
Images Images `json:"images"`
|
||||||
ToolCalls *globals.ToolCalls `json:"tool_calls"`
|
ToolCalls *globals.ToolCalls `json:"tool_calls"`
|
||||||
Charge Charge `json:"charge"`
|
Charge Charge `json:"-"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewBuffer(model string, history []globals.Message, charge Charge) *Buffer {
|
func NewBuffer(model string, history []globals.Message, charge Charge) *Buffer {
|
||||||
|
@ -47,8 +47,13 @@ func Unmarshal[T interface{}](data []byte) (form T, err error) {
|
|||||||
return form, err
|
return form, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// UnmarshalString decodes the JSON document held in data into a value of
// type T.
//
// It returns the decoded value together with any error reported by
// json.Unmarshal. On failure the returned value is whatever the decoder left
// in place, so callers should check err before using it.
func UnmarshalString[T interface{}](data string) (form T, err error) {
	if decodeErr := json.Unmarshal([]byte(data), &form); decodeErr != nil {
		return form, decodeErr
	}
	return form, nil
}
|
||||||
|
|
||||||
func UnmarshalForm[T interface{}](data string) *T {
|
func UnmarshalForm[T interface{}](data string) *T {
|
||||||
form, err := Unmarshal[T]([]byte(data))
|
form, err := UnmarshalString[T](data)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user