diff --git a/addition/generation/api.go b/addition/generation/api.go index 466768b..7128cf7 100644 --- a/addition/generation/api.go +++ b/addition/generation/api.go @@ -5,8 +5,9 @@ import ( "chat/globals" "chat/utils" "fmt" - "github.com/gin-gonic/gin" "strings" + + "github.com/gin-gonic/gin" ) type WebsocketGenerationForm struct { @@ -53,7 +54,7 @@ func GenerateAPI(c *gin.Context) { return } - check, plan := auth.CanEnableModelWithSubscription(db, cache, user, form.Model) + check, plan := auth.CanEnableModelWithSubscription(db, cache, user, form.Model, []globals.Message{}) if check != nil { conn.Send(globals.GenerationSegmentResponse{ Message: check.Error(), diff --git a/auth/rule.go b/auth/rule.go index f4416a1..8cdb6c1 100644 --- a/auth/rule.go +++ b/auth/rule.go @@ -4,6 +4,10 @@ import ( "chat/channel" "database/sql" "fmt" + + "chat/globals" + "chat/utils" + "github.com/go-redis/redis/v8" ) @@ -11,10 +15,11 @@ const ( ErrNotAuthenticated = "not authenticated error (model: %s)" ErrNotSetPrice = "the price of the model is not set (model: %s)" ErrNotEnoughQuota = "user quota is not enough error (model: %s, minimum quota: %0.2f, your quota: %0.2f)" + ErrEstimatedCost = "estimated cost exceeds user quota (model: %s, estimated cost: %0.2f, your quota: %0.2f)" ) // CanEnableModel returns whether the model can be enabled (without subscription) -func CanEnableModel(db *sql.DB, user *User, model string) error { +func CanEnableModel(db *sql.DB, user *User, model string, messages []globals.Message) error { isAuth := user != nil isAdmin := isAuth && user.IsAdmin(db) @@ -37,21 +42,23 @@ func CanEnableModel(db *sql.DB, user *User, model string) error { return fmt.Errorf(ErrNotAuthenticated, model) } - // return if the user is authenticated and has enough quota - limit := charge.GetLimit() + // Calculate estimated input cost + inputTokens := utils.NumTokensFromMessages(messages, model, false) + estimatedInputCost := float32(inputTokens) / 1000 * charge.GetInput() + // Get user's current quota quota := user.GetQuota(db) - if quota < limit { - return fmt.Errorf(ErrNotEnoughQuota, model, limit, quota) + if quota < estimatedInputCost { + return fmt.Errorf(ErrEstimatedCost, model, estimatedInputCost, quota) } return nil } -func CanEnableModelWithSubscription(db *sql.DB, cache *redis.Client, user *User, model string) (canEnable error, usePlan bool) { +func CanEnableModelWithSubscription(db *sql.DB, cache *redis.Client, user *User, model string, messages []globals.Message) (canEnable error, usePlan bool) { // use subscription quota first if user != nil && HandleSubscriptionUsage(db, cache, user, model) { return nil, true } - return CanEnableModel(db, user, model), false + return CanEnableModel(db, user, model, messages), false } diff --git a/manager/chat.go b/manager/chat.go index 4185c62..43f2783 100644 --- a/manager/chat.go +++ b/manager/chat.go @@ -202,7 +202,7 @@ func ChatHandler(conn *Connection, user *auth.User, instance *conversation.Conve model := instance.GetModel() segment := adapter.ClearMessages(model, web.ToChatSearched(instance, restart)) - check, plan := auth.CanEnableModelWithSubscription(db, cache, user, model) + check, plan := auth.CanEnableModelWithSubscription(db, cache, user, model, segment) conn.Send(globals.ChatSegmentResponse{ Conversation: instance.GetId(), }) diff --git a/manager/chat_completions.go b/manager/chat_completions.go index 36ce60b..0aa30e5 100644 --- a/manager/chat_completions.go +++ b/manager/chat_completions.go @@ -28,12 +28,12 @@ func supportRelayPlan() bool { return channel.SystemInstance.SupportRelayPlan() } -func checkEnableState(db *sql.DB, cache *redis.Client, user *auth.User, model string) (state error, plan bool) { +func checkEnableState(db *sql.DB, cache *redis.Client, user *auth.User, model string, messages []globals.Message) (state error, plan bool) { if supportRelayPlan() { - return auth.CanEnableModelWithSubscription(db, cache, user, model) + return auth.CanEnableModelWithSubscription(db, cache, user, model, messages) } - return auth.CanEnableModel(db, user, model), false + return auth.CanEnableModel(db, user, model, messages), false } func ChatRelayAPI(c *gin.Context) { @@ -80,7 +80,7 @@ func ChatRelayAPI(c *gin.Context) { form.Official = true } - check, plan := checkEnableState(db, cache, user, form.Model) + check, plan := checkEnableState(db, cache, user, form.Model, messages) if check != nil { sendErrorResponse(c, check, "quota_exceeded_error") return diff --git a/manager/completions.go b/manager/completions.go index f0762df..c0fcc10 100644 --- a/manager/completions.go +++ b/manager/completions.go @@ -28,7 +28,7 @@ func NativeChatHandler(c *gin.Context, user *auth.User, model string, message [] db := utils.GetDBFromContext(c) cache := utils.GetCacheFromContext(c) - check, plan := auth.CanEnableModelWithSubscription(db, cache, user, model) + check, plan := auth.CanEnableModelWithSubscription(db, cache, user, model, segment) if check != nil { return check.Error(), 0 diff --git a/manager/images.go b/manager/images.go index 5e80011..3317a1f 100644 --- a/manager/images.go +++ b/manager/images.go @@ -54,7 +54,7 @@ func ImagesRelayAPI(c *gin.Context) { form.Model = strings.TrimSuffix(form.Model, "-official") } - check := auth.CanEnableModel(db, user, form.Model) + check := auth.CanEnableModel(db, user, form.Model, []globals.Message{}) if check != nil { sendErrorResponse(c, check, "quota_exceeded_error") return