diff --git a/auth/rule.go b/auth/rule.go index bb00b9d..f4416a1 100644 --- a/auth/rule.go +++ b/auth/rule.go @@ -9,15 +9,21 @@ import ( const ( ErrNotAuthenticated = "not authenticated error (model: %s)" - ErrNotSetPrice = "the price of the model is not set error (model: %s)" + ErrNotSetPrice = "the price of the model is not set (model: %s)" ErrNotEnoughQuota = "user quota is not enough error (model: %s, minimum quota: %0.2f, your quota: %0.2f)" ) // CanEnableModel returns whether the model can be enabled (without subscription) func CanEnableModel(db *sql.DB, user *User, model string) error { isAuth := user != nil + isAdmin := isAuth && user.IsAdmin(db) + charge := channel.ChargeInstance.GetCharge(model) + if charge.IsUnsetType() && !isAdmin { + return fmt.Errorf(ErrNotSetPrice, model) + } + if !charge.IsBilling() { // return if is the user is authenticated or anonymous is allowed for this model if charge.SupportAnonymous() || isAuth { @@ -33,9 +39,6 @@ func CanEnableModel(db *sql.DB, user *User, model string) error { // return if the user is authenticated and has enough quota limit := charge.GetLimit() - if limit == -1 { - return fmt.Errorf(ErrNotSetPrice, model) - } quota := user.GetQuota(db) if quota < limit { diff --git a/channel/charge.go b/channel/charge.go index 460d807..710c168 100644 --- a/channel/charge.go +++ b/channel/charge.go @@ -74,6 +74,7 @@ func (m *ChargeManager) GetCharge(model string) *Charge { return &Charge{ Type: globals.NonBilling, Anonymous: false, + Unset: true, } } @@ -236,6 +237,10 @@ func (m *ChargeManager) GetRuleByModel(model string) *Charge { return nil } +func (c *Charge) IsUnsetType() bool { + return c.Unset +} + func (c *Charge) GetType() string { if c.Type == "" { return globals.NonBilling @@ -283,7 +288,7 @@ func (c *Charge) GetLimit() float32 { // 1k input tokens + 1k output tokens return c.GetInput() + c.GetOutput() default: - return -1 + return 0 } } diff --git a/channel/types.go b/channel/types.go index 4d36c66..d800663 100644 --- a/channel/types.go +++ b/channel/types.go @@ -39,6 +39,7 @@ type Charge struct { Input float32 `json:"input" mapstructure:"input"` Output float32 `json:"output" mapstructure:"output"` Anonymous bool `json:"anonymous" mapstructure:"anonymous"` + Unset bool `json:"-" mapstructure:"-"` } type ChargeSequence []*Charge