feat: improve the billing checking system

This commit is contained in:
Zhang Minghan 2024-02-03 12:04:59 +08:00
parent 09dd77efe8
commit d0b5c90df9
10 changed files with 53 additions and 18 deletions

View File

@ -62,9 +62,9 @@ func GenerateAPI(c *gin.Context) {
} }
check, plan := auth.CanEnableModelWithSubscription(db, cache, user, form.Model) check, plan := auth.CanEnableModelWithSubscription(db, cache, user, form.Model)
if !check { if check != nil {
conn.Send(globals.GenerationSegmentResponse{ conn.Send(globals.GenerationSegmentResponse{
Message: "You don't have enough quota to use this model.", Message: check.Error(),
Quota: 0, Quota: 0,
End: true, End: true,
}) })

View File

@ -37,6 +37,8 @@ export type PopupDialogProps = {
cancelLabel?: string; cancelLabel?: string;
confirmLabel?: string; confirmLabel?: string;
componentProps?: any;
}; };
type PopupFieldProps = PopupDialogProps & { type PopupFieldProps = PopupDialogProps & {
@ -50,6 +52,7 @@ function PopupField({
onValueChange, onValueChange,
value, value,
placeholder, placeholder,
componentProps,
}: PopupFieldProps) { }: PopupFieldProps) {
switch (type) { switch (type) {
case popupTypes.Text: case popupTypes.Text:
@ -62,6 +65,7 @@ function PopupField({
}} }}
value={value} value={value}
placeholder={placeholder} placeholder={placeholder}
{...componentProps}
/> />
); );
@ -72,6 +76,7 @@ function PopupField({
value={parseFloat(value)} value={parseFloat(value)}
onValueChange={(v) => setValue(v.toString())} onValueChange={(v) => setValue(v.toString())}
placeholder={placeholder} placeholder={placeholder}
{...componentProps}
/> />
); );
@ -82,6 +87,7 @@ function PopupField({
onCheckedChange={(state: boolean) => { onCheckedChange={(state: boolean) => {
setValue(state.toString()); setValue(state.toString());
}} }}
{...componentProps}
/> />
); );

View File

@ -137,6 +137,7 @@ function OperationMenu({ user, onRefresh }: OperationMenuProps) {
onValueChange={getNumber} onValueChange={getNumber}
open={quotaOpen} open={quotaOpen}
setOpen={setQuotaOpen} setOpen={setQuotaOpen}
componentProps={{ acceptNegative: true }}
onSubmit={async (value) => { onSubmit={async (value) => {
const quota = parseNumber(value); const quota = parseNumber(value);
const resp = await quotaOperation(user.id, quota); const resp = await quotaOperation(user.id, quota);
@ -155,6 +156,7 @@ function OperationMenu({ user, onRefresh }: OperationMenuProps) {
onValueChange={getNumber} onValueChange={getNumber}
open={quotaSetOpen} open={quotaSetOpen}
setOpen={setQuotaSetOpen} setOpen={setQuotaSetOpen}
componentProps={{ acceptNegative: true }}
onSubmit={async (value) => { onSubmit={async (value) => {
const quota = parseNumber(value); const quota = parseNumber(value);
const resp = await quotaOperation(user.id, quota, true); const resp = await quotaOperation(user.id, quota, true);

View File

@ -40,6 +40,8 @@ const NumberInput = React.forwardRef<HTMLInputElement, NumberInputProps>(
return v.match(exp)?.join("") || ""; return v.match(exp)?.join("") || "";
} }
if (v === "-" && props.acceptNegative) return v;
// replace -0124.5 to -124.5, 0043 to 43, 2.000 to 2.000 // replace -0124.5 to -124.5, 0043 to 43, 2.000 to 2.000
const exp = /^[-+]?0+(?=[0-9]+(\.[0-9]+)?$)/; const exp = /^[-+]?0+(?=[0-9]+(\.[0-9]+)?$)/;
v = v.replace(exp, ""); v = v.replace(exp, "");

View File

@ -3,27 +3,52 @@ package auth
import ( import (
"chat/channel" "chat/channel"
"database/sql" "database/sql"
"fmt"
"github.com/go-redis/redis/v8" "github.com/go-redis/redis/v8"
) )
const (
ErrNotAuthenticated = "not authenticated error (model: %s)"
ErrNotSetPrice = "the price of the model is not set error (model: %s)"
ErrNotEnoughQuota = "user quota is not enough error (model: %s, minimum quota: %0.2f, your quota: %0.2f)"
)
// CanEnableModel returns whether the model can be enabled (without subscription) // CanEnableModel returns whether the model can be enabled (without subscription)
func CanEnableModel(db *sql.DB, user *User, model string) bool { func CanEnableModel(db *sql.DB, user *User, model string) error {
isAuth := user != nil isAuth := user != nil
charge := channel.ChargeInstance.GetCharge(model) charge := channel.ChargeInstance.GetCharge(model)
if !charge.IsBilling() { if !charge.IsBilling() {
// return if is the user is authenticated or anonymous is allowed for this model // return if is the user is authenticated or anonymous is allowed for this model
return charge.SupportAnonymous() || isAuth if charge.SupportAnonymous() || isAuth {
return nil
}
return fmt.Errorf(ErrNotAuthenticated, model)
}
if !isAuth {
return fmt.Errorf(ErrNotAuthenticated, model)
} }
// return if the user is authenticated and has enough quota // return if the user is authenticated and has enough quota
return isAuth && user.GetQuota(db) >= charge.GetLimit() limit := charge.GetLimit()
if limit == -1 {
return fmt.Errorf(ErrNotSetPrice, model)
}
quota := user.GetQuota(db)
if quota < limit {
return fmt.Errorf(ErrNotEnoughQuota, model, limit, quota)
}
return nil
} }
func CanEnableModelWithSubscription(db *sql.DB, cache *redis.Client, user *User, model string) (canEnable bool, usePlan bool) { func CanEnableModelWithSubscription(db *sql.DB, cache *redis.Client, user *User, model string) (canEnable error, usePlan bool) {
// use subscription quota first // use subscription quota first
if user != nil && HandleSubscriptionUsage(db, cache, user, model) { if user != nil && HandleSubscriptionUsage(db, cache, user, model) {
return true, true return nil, true
} }
return CanEnableModel(db, user, model), false return CanEnableModel(db, user, model), false
} }

View File

@ -283,7 +283,7 @@ func (c *Charge) GetLimit() float32 {
// 1k input tokens + 1k output tokens // 1k input tokens + 1k output tokens
return c.GetInput() + c.GetOutput() return c.GetInput() + c.GetOutput()
default: default:
return 0 return -1
} }
} }

View File

@ -15,7 +15,6 @@ import (
) )
const defaultMessage = "empty response" const defaultMessage = "empty response"
const defaultQuotaMessage = "You don't have enough quota or you don't have permission to use this model. please [buy](/buy) or [subscribe](/subscribe) to get more."
func CollectQuota(c *gin.Context, user *auth.User, buffer *utils.Buffer, uncountable bool, err error) { func CollectQuota(c *gin.Context, user *auth.User, buffer *utils.Buffer, uncountable bool, err error) {
db := utils.GetDBFromContext(c) db := utils.GetDBFromContext(c)
@ -73,13 +72,14 @@ func ChatHandler(conn *Connection, user *auth.User, instance *conversation.Conve
Conversation: instance.GetId(), Conversation: instance.GetId(),
}) })
if !check { if check != nil {
message := check.Error()
conn.Send(globals.ChatSegmentResponse{ conn.Send(globals.ChatSegmentResponse{
Message: defaultQuotaMessage, Message: message,
Quota: 0, Quota: 0,
End: true, End: true,
}) })
return defaultQuotaMessage return message
} }
if form := ExtractCacheData(conn.GetCtx(), &CacheProps{ if form := ExtractCacheData(conn.GetCtx(), &CacheProps{

View File

@ -55,8 +55,8 @@ func ChatRelayAPI(c *gin.Context) {
} }
check := auth.CanEnableModel(db, user, form.Model) check := auth.CanEnableModel(db, user, form.Model)
if !check { if check != nil {
sendErrorResponse(c, fmt.Errorf("quota exceeded"), "quota_exceeded_error") sendErrorResponse(c, check, "quota_exceeded_error")
return return
} }

View File

@ -27,8 +27,8 @@ func NativeChatHandler(c *gin.Context, user *auth.User, model string, message []
cache := utils.GetCacheFromContext(c) cache := utils.GetCacheFromContext(c)
check, plan := auth.CanEnableModelWithSubscription(db, cache, user, model) check, plan := auth.CanEnableModelWithSubscription(db, cache, user, model)
if !check { if check != nil {
return defaultQuotaMessage, 0 return check.Error(), 0
} }
if form := ExtractCacheData(c, &CacheProps{ if form := ExtractCacheData(c, &CacheProps{

View File

@ -49,8 +49,8 @@ func ImagesRelayAPI(c *gin.Context) {
} }
check := auth.CanEnableModel(db, user, form.Model) check := auth.CanEnableModel(db, user, form.Model)
if !check { if check != nil {
sendErrorResponse(c, fmt.Errorf("quota exceeded"), "quota_exceeded_error") sendErrorResponse(c, check, "quota_exceeded_error")
return return
} }