mirror of
https://github.com/coaidev/coai.git
synced 2025-05-19 21:10:18 +09:00
203 lines
5.2 KiB
Go
203 lines
5.2 KiB
Go
package auth
|
|
|
|
import (
|
|
"chat/channel"
|
|
"chat/globals"
|
|
"chat/utils"
|
|
"database/sql"
|
|
"errors"
|
|
"fmt"
|
|
"github.com/go-redis/redis/v8"
|
|
"math"
|
|
"time"
|
|
)
|
|
|
|
func disableSubscription() bool {
|
|
return !channel.PlanInstance.IsEnabled()
|
|
}
|
|
|
|
func (u *User) GetSubscription(db *sql.DB) (time.Time, int) {
|
|
if u.Subscription != nil && u.Subscription.Unix() > 0 {
|
|
return *u.Subscription, u.Level
|
|
}
|
|
|
|
var expiredAt []uint8
|
|
if err := globals.QueryRowDb(db, "SELECT expired_at, level FROM subscription WHERE user_id = ?", u.GetID(db)).Scan(&expiredAt, &u.Level); err != nil {
|
|
return time.Unix(0, 0), 0
|
|
}
|
|
|
|
t := utils.ConvertTime(expiredAt)
|
|
if t == nil {
|
|
t = utils.ToPtr(time.Unix(0, 0))
|
|
}
|
|
|
|
u.Subscription = t
|
|
return *u.Subscription, u.Level
|
|
}
|
|
|
|
func (u *User) GetSubscriptionLevel(db *sql.DB) int {
|
|
_, level := u.GetSubscription(db)
|
|
if !u.IsSubscribe(db) {
|
|
return 0
|
|
}
|
|
return level
|
|
}
|
|
|
|
func (u *User) GetPlan(db *sql.DB) channel.Plan {
|
|
return channel.PlanInstance.GetPlan(u.GetSubscriptionLevel(db))
|
|
}
|
|
|
|
func (u *User) GetSubscriptionExpiredAt(db *sql.DB) time.Time {
|
|
stamp, _ := u.GetSubscription(db)
|
|
return stamp
|
|
}
|
|
|
|
func (u *User) GetSubscriptionTime(db *sql.DB) time.Time {
|
|
stamp, _ := u.GetSubscription(db)
|
|
return stamp
|
|
}
|
|
|
|
func (u *User) IsSubscribe(db *sql.DB) bool {
|
|
stamp, level := u.GetSubscription(db)
|
|
return stamp.Unix() > time.Now().Unix() && level > 0
|
|
}
|
|
|
|
func (u *User) IsEnterprise(db *sql.DB) bool {
|
|
if !u.IsSubscribe(db) {
|
|
return false
|
|
}
|
|
|
|
var enterprise sql.NullBool
|
|
if err := globals.QueryRowDb(db, "SELECT enterprise FROM subscription WHERE user_id = ?", u.GetID(db)).Scan(&enterprise); err != nil {
|
|
return false
|
|
}
|
|
|
|
return enterprise.Valid && enterprise.Bool
|
|
}
|
|
|
|
func (u *User) GetSubscriptionExpiredDay(db *sql.DB) int {
|
|
stamp := u.GetSubscriptionTime(db).Sub(time.Now())
|
|
return int(math.Round(stamp.Hours() / 24))
|
|
}
|
|
|
|
func (u *User) AddSubscription(db *sql.DB, month int, level int) bool {
|
|
current := u.GetSubscriptionTime(db)
|
|
if current.Unix() < time.Now().Unix() {
|
|
current = time.Now()
|
|
}
|
|
expiredAt := current.AddDate(0, month, 0)
|
|
date := utils.ConvertSqlTime(expiredAt)
|
|
_, err := globals.ExecDb(db, `
|
|
INSERT INTO subscription (user_id, expired_at, total_month, level) VALUES (?, ?, ?, ?)
|
|
ON DUPLICATE KEY UPDATE expired_at = ?, total_month = total_month + ?, level = ?
|
|
`, u.GetID(db), date, month, level, date, month, level)
|
|
return err == nil
|
|
}
|
|
|
|
func (u *User) DowngradePlan(db *sql.DB, target int) error {
|
|
expired, current := u.GetSubscription(db)
|
|
if current == 0 || current == target {
|
|
return fmt.Errorf("invalid plan level")
|
|
}
|
|
|
|
now := time.Now()
|
|
weight := channel.PlanInstance.GetPlan(current).Price / channel.PlanInstance.GetPlan(target).Price
|
|
stamp := float32(expired.Unix()-now.Unix()) * weight
|
|
|
|
// ceil expired time
|
|
expiredAt := now.Add(time.Duration(stamp)*time.Second).AddDate(0, 0, -1)
|
|
date := utils.ConvertSqlTime(expiredAt)
|
|
_, err := globals.ExecDb(db, "UPDATE subscription SET level = ?, expired_at = ? WHERE user_id = ?", target, date, u.GetID(db))
|
|
|
|
return err
|
|
}
|
|
|
|
func (u *User) CountUpgradePrice(db *sql.DB, target int) float32 {
|
|
expired := u.GetSubscriptionExpiredAt(db)
|
|
weight := channel.PlanInstance.GetPlan(target).Price - u.GetPlan(db).Price
|
|
if weight < 0 {
|
|
return 0
|
|
}
|
|
|
|
days := expired.Sub(time.Now()).Hours() / 24
|
|
return float32(days) * weight / 30
|
|
}
|
|
|
|
func (u *User) SetSubscriptionLevel(db *sql.DB, level int) bool {
|
|
_, err := globals.ExecDb(db, "UPDATE subscription SET level = ? WHERE user_id = ?", level, u.GetID(db))
|
|
return err == nil
|
|
}
|
|
|
|
func CountSubscriptionPrize(level int, month int) float32 {
|
|
plan := channel.PlanInstance.GetPlan(level)
|
|
base := plan.Price * float32(month)
|
|
if month >= 36 {
|
|
return base * 0.7
|
|
} else if month >= 12 {
|
|
return base * 0.8
|
|
} else if month >= 6 {
|
|
return base * 0.9
|
|
}
|
|
return base
|
|
}
|
|
|
|
func BuySubscription(db *sql.DB, cache *redis.Client, user *User, level int, month int) error {
|
|
if disableSubscription() {
|
|
return errors.New("subscription feature does not enable of this site")
|
|
}
|
|
|
|
if month < 1 || month > 999 || !channel.IsValidPlan(level) {
|
|
return errors.New("invalid subscription params")
|
|
}
|
|
|
|
before := user.GetSubscriptionLevel(db)
|
|
if before == 0 || before == level {
|
|
// buy new subscription or renew subscription
|
|
money := CountSubscriptionPrize(level, month)
|
|
if user.Pay(db, cache, money) {
|
|
// migrate subscription
|
|
user.AddSubscription(db, month, level)
|
|
|
|
if before == 0 {
|
|
// new subscription
|
|
|
|
plan := user.GetPlan(db)
|
|
for _, usage := range plan.Items {
|
|
// create usage
|
|
usage.CreateUsage(user, cache)
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
} else if before > level {
|
|
// downgrade subscription
|
|
return user.DowngradePlan(db, level)
|
|
} else {
|
|
// upgrade subscription
|
|
money := user.CountUpgradePrice(db, level)
|
|
if user.Pay(db, cache, money) {
|
|
user.SetSubscriptionLevel(db, level)
|
|
return nil
|
|
}
|
|
}
|
|
|
|
return errors.New("not enough money")
|
|
}
|
|
|
|
func HandleSubscriptionUsage(db *sql.DB, cache *redis.Client, user *User, model string) bool {
|
|
if disableSubscription() {
|
|
return false
|
|
}
|
|
plan := user.GetPlan(db)
|
|
return plan.IncreaseUsage(user, cache, model)
|
|
}
|
|
|
|
func RevertSubscriptionUsage(db *sql.DB, cache *redis.Client, user *User, model string) bool {
|
|
if disableSubscription() {
|
|
return false
|
|
}
|
|
plan := user.GetPlan(db)
|
|
return plan.DecreaseUsage(user, cache, model)
|
|
}
|