diff --git a/manager/chat_completions.go b/manager/chat_completions.go index 139fe8e..ed7152e 100644 --- a/manager/chat_completions.go +++ b/manager/chat_completions.go @@ -1,19 +1,22 @@ package manager import ( - "chat/adapter/common" + adaptercommon "chat/adapter/common" "chat/addition/web" "chat/admin" "chat/auth" "chat/channel" "chat/globals" "chat/utils" + "database/sql" "fmt" - "github.com/gin-gonic/gin" "io" "net/http" "strings" "time" + + "github.com/gin-gonic/gin" + "github.com/go-redis/redis/v8" ) const ( @@ -25,6 +28,14 @@ func supportRelayPlan() bool { return channel.SystemInstance.SupportRelayPlan() } +func checkEnableState(db *sql.DB, cache *redis.Client, user *auth.User, model string) (state error, plan bool) { + if supportRelayPlan() { + return auth.CanEnableModelWithSubscription(db, cache, user, model) + } + + return auth.CanEnableModel(db, user, model), false +} + func ChatRelayAPI(c *gin.Context) { if globals.CloseRelay { abortWithErrorResponse(c, fmt.Errorf("relay api is denied of access"), "access_denied_error") @@ -49,6 +60,7 @@ func ChatRelayAPI(c *gin.Context) { } db := utils.GetDBFromContext(c) + cache := utils.GetCacheFromContext(c) user := &auth.User{ Username: username, } @@ -68,20 +80,20 @@ func ChatRelayAPI(c *gin.Context) { form.Official = true } - check := auth.CanEnableModel(db, user, form.Model) + check, plan := checkEnableState(db, cache, user, form.Model) if check != nil { sendErrorResponse(c, check, "quota_exceeded_error") return } if form.Stream { - sendStreamTranshipmentResponse(c, form, messages, id, created, user, supportRelayPlan()) + sendStreamTranshipmentResponse(c, form, messages, id, created, user, plan) } else { - sendTranshipmentResponse(c, form, messages, id, created, user, supportRelayPlan()) + sendTranshipmentResponse(c, form, messages, id, created, user, plan) } } -func getChatProps(form RelayForm, messages []globals.Message, buffer *utils.Buffer, plan bool) *adaptercommon.ChatProps { +func getChatProps(form RelayForm, messages []globals.Message, buffer *utils.Buffer) *adaptercommon.ChatProps { return &adaptercommon.ChatProps{ Model: form.Model, Message: messages, @@ -103,7 +115,7 @@ func sendTranshipmentResponse(c *gin.Context, form RelayForm, messages []globals cache := utils.GetCacheFromContext(c) buffer := utils.NewBuffer(form.Model, messages, channel.ChargeInstance.GetCharge(form.Model)) - hit, err := channel.NewChatRequestWithCache(cache, buffer, auth.GetGroup(db, user), getChatProps(form, messages, buffer, plan), func(data *globals.Chunk) error { + hit, err := channel.NewChatRequestWithCache(cache, buffer, auth.GetGroup(db, user), getChatProps(form, messages, buffer), func(data *globals.Chunk) error { buffer.WriteChunk(data) return nil }) @@ -212,7 +224,7 @@ func sendStreamTranshipmentResponse(c *gin.Context, form RelayForm, messages []g go func() { buffer := utils.NewBuffer(form.Model, messages, charge) hit, err := channel.NewChatRequestWithCache( - cache, buffer, group, getChatProps(form, messages, buffer, plan), + cache, buffer, group, getChatProps(form, messages, buffer), func(data *globals.Chunk) error { buffer.WriteChunk(data)