mirror of
https://github.com/coaidev/coai.git
synced 2025-05-19 13:00:14 +09:00
feat: add input token and quota calculation for API request pre-checks
This commit is contained in:
parent
4d841c4a33
commit
9f6cec4298
@ -5,8 +5,9 @@ import (
|
|||||||
"chat/globals"
|
"chat/globals"
|
||||||
"chat/utils"
|
"chat/utils"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
type WebsocketGenerationForm struct {
|
type WebsocketGenerationForm struct {
|
||||||
@ -53,7 +54,7 @@ func GenerateAPI(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
check, plan := auth.CanEnableModelWithSubscription(db, cache, user, form.Model)
|
check, plan := auth.CanEnableModelWithSubscription(db, cache, user, form.Model, []globals.Message{})
|
||||||
if check != nil {
|
if check != nil {
|
||||||
conn.Send(globals.GenerationSegmentResponse{
|
conn.Send(globals.GenerationSegmentResponse{
|
||||||
Message: check.Error(),
|
Message: check.Error(),
|
||||||
|
21
auth/rule.go
21
auth/rule.go
@ -4,6 +4,10 @@ import (
|
|||||||
"chat/channel"
|
"chat/channel"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
|
"chat/globals"
|
||||||
|
"chat/utils"
|
||||||
|
|
||||||
"github.com/go-redis/redis/v8"
|
"github.com/go-redis/redis/v8"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -11,10 +15,11 @@ const (
|
|||||||
ErrNotAuthenticated = "not authenticated error (model: %s)"
|
ErrNotAuthenticated = "not authenticated error (model: %s)"
|
||||||
ErrNotSetPrice = "the price of the model is not set (model: %s)"
|
ErrNotSetPrice = "the price of the model is not set (model: %s)"
|
||||||
ErrNotEnoughQuota = "user quota is not enough error (model: %s, minimum quota: %0.2f, your quota: %0.2f)"
|
ErrNotEnoughQuota = "user quota is not enough error (model: %s, minimum quota: %0.2f, your quota: %0.2f)"
|
||||||
|
ErrEstimatedCost = "estimated cost exceeds user quota (model: %s, estimated cost: %0.2f, your quota: %0.2f)"
|
||||||
)
|
)
|
||||||
|
|
||||||
// CanEnableModel returns whether the model can be enabled (without subscription)
|
// CanEnableModel returns whether the model can be enabled (without subscription)
|
||||||
func CanEnableModel(db *sql.DB, user *User, model string) error {
|
func CanEnableModel(db *sql.DB, user *User, model string, messages []globals.Message) error {
|
||||||
isAuth := user != nil
|
isAuth := user != nil
|
||||||
isAdmin := isAuth && user.IsAdmin(db)
|
isAdmin := isAuth && user.IsAdmin(db)
|
||||||
|
|
||||||
@ -37,21 +42,23 @@ func CanEnableModel(db *sql.DB, user *User, model string) error {
|
|||||||
return fmt.Errorf(ErrNotAuthenticated, model)
|
return fmt.Errorf(ErrNotAuthenticated, model)
|
||||||
}
|
}
|
||||||
|
|
||||||
// return if the user is authenticated and has enough quota
|
// Calculate estimated input cost
|
||||||
limit := charge.GetLimit()
|
inputTokens := utils.NumTokensFromMessages(messages, model, false)
|
||||||
|
estimatedInputCost := float32(inputTokens) / 1000 * charge.GetInput()
|
||||||
|
|
||||||
|
// Get user's current quota
|
||||||
quota := user.GetQuota(db)
|
quota := user.GetQuota(db)
|
||||||
if quota < limit {
|
if quota < estimatedInputCost {
|
||||||
return fmt.Errorf(ErrNotEnoughQuota, model, limit, quota)
|
return fmt.Errorf(ErrEstimatedCost, model, estimatedInputCost, quota)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func CanEnableModelWithSubscription(db *sql.DB, cache *redis.Client, user *User, model string) (canEnable error, usePlan bool) {
|
func CanEnableModelWithSubscription(db *sql.DB, cache *redis.Client, user *User, model string, messages []globals.Message) (canEnable error, usePlan bool) {
|
||||||
// use subscription quota first
|
// use subscription quota first
|
||||||
if user != nil && HandleSubscriptionUsage(db, cache, user, model) {
|
if user != nil && HandleSubscriptionUsage(db, cache, user, model) {
|
||||||
return nil, true
|
return nil, true
|
||||||
}
|
}
|
||||||
return CanEnableModel(db, user, model), false
|
return CanEnableModel(db, user, model, messages), false
|
||||||
}
|
}
|
||||||
|
@ -202,7 +202,7 @@ func ChatHandler(conn *Connection, user *auth.User, instance *conversation.Conve
|
|||||||
model := instance.GetModel()
|
model := instance.GetModel()
|
||||||
segment := adapter.ClearMessages(model, web.ToChatSearched(instance, restart))
|
segment := adapter.ClearMessages(model, web.ToChatSearched(instance, restart))
|
||||||
|
|
||||||
check, plan := auth.CanEnableModelWithSubscription(db, cache, user, model)
|
check, plan := auth.CanEnableModelWithSubscription(db, cache, user, model, segment)
|
||||||
conn.Send(globals.ChatSegmentResponse{
|
conn.Send(globals.ChatSegmentResponse{
|
||||||
Conversation: instance.GetId(),
|
Conversation: instance.GetId(),
|
||||||
})
|
})
|
||||||
|
@ -28,12 +28,12 @@ func supportRelayPlan() bool {
|
|||||||
return channel.SystemInstance.SupportRelayPlan()
|
return channel.SystemInstance.SupportRelayPlan()
|
||||||
}
|
}
|
||||||
|
|
||||||
func checkEnableState(db *sql.DB, cache *redis.Client, user *auth.User, model string) (state error, plan bool) {
|
func checkEnableState(db *sql.DB, cache *redis.Client, user *auth.User, model string, messages []globals.Message) (state error, plan bool) {
|
||||||
if supportRelayPlan() {
|
if supportRelayPlan() {
|
||||||
return auth.CanEnableModelWithSubscription(db, cache, user, model)
|
return auth.CanEnableModelWithSubscription(db, cache, user, model, messages)
|
||||||
}
|
}
|
||||||
|
|
||||||
return auth.CanEnableModel(db, user, model), false
|
return auth.CanEnableModel(db, user, model, messages), false
|
||||||
}
|
}
|
||||||
|
|
||||||
func ChatRelayAPI(c *gin.Context) {
|
func ChatRelayAPI(c *gin.Context) {
|
||||||
@ -80,7 +80,7 @@ func ChatRelayAPI(c *gin.Context) {
|
|||||||
form.Official = true
|
form.Official = true
|
||||||
}
|
}
|
||||||
|
|
||||||
check, plan := checkEnableState(db, cache, user, form.Model)
|
check, plan := checkEnableState(db, cache, user, form.Model, messages)
|
||||||
if check != nil {
|
if check != nil {
|
||||||
sendErrorResponse(c, check, "quota_exceeded_error")
|
sendErrorResponse(c, check, "quota_exceeded_error")
|
||||||
return
|
return
|
||||||
|
@ -28,7 +28,7 @@ func NativeChatHandler(c *gin.Context, user *auth.User, model string, message []
|
|||||||
|
|
||||||
db := utils.GetDBFromContext(c)
|
db := utils.GetDBFromContext(c)
|
||||||
cache := utils.GetCacheFromContext(c)
|
cache := utils.GetCacheFromContext(c)
|
||||||
check, plan := auth.CanEnableModelWithSubscription(db, cache, user, model)
|
check, plan := auth.CanEnableModelWithSubscription(db, cache, user, model, segment)
|
||||||
|
|
||||||
if check != nil {
|
if check != nil {
|
||||||
return check.Error(), 0
|
return check.Error(), 0
|
||||||
|
@ -54,7 +54,7 @@ func ImagesRelayAPI(c *gin.Context) {
|
|||||||
form.Model = strings.TrimSuffix(form.Model, "-official")
|
form.Model = strings.TrimSuffix(form.Model, "-official")
|
||||||
}
|
}
|
||||||
|
|
||||||
check := auth.CanEnableModel(db, user, form.Model)
|
check := auth.CanEnableModel(db, user, form.Model, []globals.Message{})
|
||||||
if check != nil {
|
if check != nil {
|
||||||
sendErrorResponse(c, check, "quota_exceeded_error")
|
sendErrorResponse(c, check, "quota_exceeded_error")
|
||||||
return
|
return
|
||||||
|
Loading…
Reference in New Issue
Block a user