From 898b336afe8355de6c27c3943822deb3c0490142 Mon Sep 17 00:00:00 2001 From: Zhang Minghan Date: Tue, 19 Dec 2023 18:29:23 +0800 Subject: [PATCH] feat: update subscription rule --- app/src/api/types.ts | 14 ++ app/src/assets/pages/subscription.less | 8 + app/src/components/app/AppProvider.tsx | 2 - .../home/subscription/BuyDialog.tsx | 12 +- app/src/conf.ts | 22 ++- app/src/dialogs/SubscriptionDialog.tsx | 68 ++++---- app/src/i18n.ts | 24 +-- auth/plan.go | 153 ++++++++++++++---- auth/subscription.go | 12 ++ globals/usage.go | 3 +- main.go | 4 +- manager/transhipment.go | 7 +- middleware/middleware.go | 12 +- utils/cache.go | 8 + 14 files changed, 264 insertions(+), 85 deletions(-) diff --git a/app/src/api/types.ts b/app/src/api/types.ts index 5f5e3f1..b39160d 100644 --- a/app/src/api/types.ts +++ b/app/src/api/types.ts @@ -1,4 +1,5 @@ import { Conversation } from "./conversation.ts"; +import React from "react"; export type Message = { role: string; @@ -33,3 +34,16 @@ export type ConversationInstance = { }; export type ConversationMapper = Record; + +export type Plan = { + level: number; + price: number; +}; + +export type SubscriptionUsage = Record< + string, + { + icon: React.ReactElement; + name: string; + } +>; diff --git a/app/src/assets/pages/subscription.less b/app/src/assets/pages/subscription.less index a29f945..26dd435 100644 --- a/app/src/assets/pages/subscription.less +++ b/app/src/assets/pages/subscription.less @@ -106,6 +106,10 @@ font-size: 18px; font-weight: bold; margin: 2px auto; + + .tax { + color: hsl(var(--text-secondary)); + } } .annotate { @@ -152,6 +156,10 @@ margin-top: 12px; text-align: center; transform: translateY(12px); + + .tax { + color: hsl(var(--text-secondary)); + } } } diff --git a/app/src/components/app/AppProvider.tsx b/app/src/components/app/AppProvider.tsx index 0160aa6..5679d85 100644 --- a/app/src/components/app/AppProvider.tsx +++ b/app/src/components/app/AppProvider.tsx @@ -9,8 +9,6 @@ import { channelModels } from "@/admin/channel.ts"; function AppProvider() { useEffectAsync(async () => { - if (allModels.length !== 0) return; - const res = await axios.get("/v1/models"); res.data.forEach((model: string) => { if (!allModels.includes(model)) allModels.push(model); diff --git a/app/src/components/home/subscription/BuyDialog.tsx b/app/src/components/home/subscription/BuyDialog.tsx index 6f2ef22..377b47a 100644 --- a/app/src/components/home/subscription/BuyDialog.tsx +++ b/app/src/components/home/subscription/BuyDialog.tsx @@ -25,7 +25,7 @@ import { expiredSelector, refreshSubscription } from "@/store/subscription.ts"; import { Plus } from "lucide-react"; import { subscriptionPrize } from "@/conf.ts"; import { ToastAction } from "@/components/ui/toast.tsx"; -import { deeptrainEndpoint } from "@/utils/env.ts"; +import { deeptrainEndpoint, useDeeptrain } from "@/utils/env.ts"; function countPrize(base: number, month: number): number { const prize = subscriptionPrize[base] * month; @@ -160,6 +160,16 @@ export function Upgrade({ base, level }: UpgradeProps) {

{t("sub.price", { price: countPrize(base, month).toFixed(2) })} + + {useDeeptrain && ( + +   ( + {t("sub.price-tax", { + price: (countPrize(base, month) * 0.25).toFixed(1), + })} + ) + + )}

diff --git a/app/src/conf.ts b/app/src/conf.ts index a03379c..52d6d90 100644 --- a/app/src/conf.ts +++ b/app/src/conf.ts @@ -1,5 +1,5 @@ import axios from "axios"; -import { Model, PlanModel } from "@/api/types.ts"; +import { Model, PlanModel, SubscriptionUsage } from "@/api/types.ts"; import { deeptrainAppName, deeptrainEndpoint, @@ -9,6 +9,8 @@ import { getWebsocketApi, } from "@/utils/env.ts"; import { getMemory } from "@/utils/memory.ts"; +import { Compass, Image, Newspaper } from "lucide-react"; +import React from "react"; export const version = "3.7.6"; export const dev: boolean = getDev(); @@ -336,7 +338,7 @@ export const defaultModels = [ "stable-diffusion", ]; -export let allModels: string[] = []; +export let allModels: string[] = supportModels.map((model) => model.id); export const largeContextModels = [ "gpt-3.5-turbo-16k-0613", @@ -362,10 +364,10 @@ export const planModels: PlanModel[] = [ { id: "claude-2", level: 1 }, { id: "claude-2.1", level: 1 }, { id: "claude-2-100k", level: 1 }, - { id: "midjourney-fast", level: 2 }, + { id: "midjourney-fast", level: 1 }, ]; -export const expensiveModels = ["midjourney-turbo", "gpt-4-32k-0613"]; +export const expensiveModels = ["gpt-4-32k-0613"]; export const modelAvatars: Record = { "gpt-3.5-turbo-0613": "gpt35turbo.png", @@ -410,9 +412,15 @@ export const modelAvatars: Record = { }; export const subscriptionPrize: Record = { - 1: 18, - 2: 36, - 3: 72, + 1: 42, + 2: 76, + 3: 148, +}; + +export const subscriptionUsage: SubscriptionUsage = { + midjourney: { name: "Midjourney", icon: React.createElement(Image) }, + "gpt-4": { name: "GPT-4", icon: React.createElement(Compass) }, + "claude-100k": { name: "Claude 100k", icon: React.createElement(Newspaper) }, }; export function login() { diff --git a/app/src/dialogs/SubscriptionDialog.tsx b/app/src/dialogs/SubscriptionDialog.tsx index 3f49fd8..97f7e83 100644 --- a/app/src/dialogs/SubscriptionDialog.tsx +++ b/app/src/dialogs/SubscriptionDialog.tsx @@ -26,18 +26,17 @@ import { BookText, Calendar, Compass, - Image, ImagePlus, LifeBuoy, - Newspaper, ServerCrash, } from "lucide-react"; import { useEffectAsync } from "@/utils/hook.ts"; import { selectAuthenticated } from "@/store/auth.ts"; import SubscriptionUsage from "@/components/home/subscription/SubscriptionUsage.tsx"; import Tips from "@/components/Tips.tsx"; -import { subscriptionPrize } from "@/conf.ts"; +import { subscriptionPrize, subscriptionUsage } from "@/conf.ts"; import { Upgrade } from "@/components/home/subscription/BuyDialog.tsx"; +import { useDeeptrain } from "@/utils/env.ts"; function SubscriptionDialog() { const { t } = useTranslation(); @@ -47,7 +46,6 @@ function SubscriptionDialog() { const expired = useSelector(expiredSelector); const usage = useSelector(usageSelector); const auth = useSelector(selectAuthenticated); - const quota = useSelector(quotaDialogSelector); const dispatch = useDispatch(); @@ -84,21 +82,17 @@ function SubscriptionDialog() { name={t("sub.expired")} usage={expired} /> - } - name={"Midjourney"} - usage={usage?.["midjourney"]} - /> - } - name={"GPT-4"} - usage={usage?.["gpt-4"]} - /> - } - name={"Claude 100k"} - usage={usage?.["claude-100k"]} - /> + + {Object.entries(subscriptionUsage).map( + ([key, props], index) => + usage?.[key] && ( + + ), + )} )}
@@ -108,17 +102,25 @@ function SubscriptionDialog() {
{t("sub.plan-price", { money: subscriptionPrize[1] })}
-

({t("sub.include-tax")})

+ {useDeeptrain && ( +

({t("sub.include-tax")})

+ )}
- {t("sub.plan-gpt4", { times: 25 })} + {t("sub.plan-gpt4", { times: 150 })}
+
+ + {t("sub.plan-midjourney", { times: 50 })} + +
- {t("sub.plan-claude", { times: 50 })} + {t("sub.plan-claude", { times: 300 })} +
@@ -129,7 +131,9 @@ function SubscriptionDialog() {
{t("sub.plan-price", { money: subscriptionPrize[2] })}
-

({t("sub.include-tax")})

+ {useDeeptrain && ( +

({t("sub.include-tax")})

+ )}
@@ -138,17 +142,18 @@ function SubscriptionDialog() {
- {t("sub.plan-gpt4", { times: 50 })} + {t("sub.plan-gpt4", { times: 300 })}
- {t("sub.plan-midjourney", { times: 25 })} + {t("sub.plan-midjourney", { times: 100 })}
- {t("sub.plan-claude", { times: 100 })} + {t("sub.plan-claude", { times: 600 })} +
@@ -159,7 +164,9 @@ function SubscriptionDialog() {
{t("sub.plan-price", { money: subscriptionPrize[3] })}
-

({t("sub.include-tax")})

+ {useDeeptrain && ( +

({t("sub.include-tax")})

+ )}
@@ -168,17 +175,18 @@ function SubscriptionDialog() {
- {t("sub.plan-gpt4", { times: 100 })} + {t("sub.plan-gpt4", { times: 600 })}
- {t("sub.plan-midjourney", { times: 50 })} + {t("sub.plan-midjourney", { times: 200 })}
- {t("sub.plan-claude", { times: 200 })} + {t("sub.plan-claude", { times: 1200 })} +
diff --git a/app/src/i18n.ts b/app/src/i18n.ts index 86c0d6a..69e1ec9 100644 --- a/app/src/i18n.ts +++ b/app/src/i18n.ts @@ -158,11 +158,12 @@ const resources = { "free-conversation": "对话存储记录", "free-sharing": "对话分享功能", "free-api": "API 调用", - "plan-midjourney": "Midjourney 每日绘图 {{times}} 次", + "plan-midjourney": "Midjourney 每月绘图 {{times}} 次", "plan-midjourney-desc": "Midjourney 快速出图模式", - "plan-gpt4": "GPT-4 每日请求 {{times}} 次", + "plan-gpt4": "GPT-4 每月配额 {{times}} 次", "plan-gpt4-desc": "包含 GPT 4 Turbo, GPT 4V, GPT 4 DALLE", - "plan-claude": "Claude 100k 每日请求 {{times}} 次", + "plan-claude": "Claude 100k 每月配额 {{times}} 次", + "plan-claude-desc": "包含 Claude 2 (100k), Claude 2.1 (200k)", "pro-service": "优先服务支持", "pro-thread": "并发数提升", enterprise: "企业版", @@ -184,6 +185,7 @@ const resources = { "migrate-plan-desc": "变更订阅后,您的订阅时间将会根据剩余天数价格计算,重新计算订阅时间。(如降级会时间翻倍,升级会补齐差价)", price: "价格 {{price}} 元", + "price-tax": "含税 {{price}} 元", "upgrade-price": "升级费用 {{price}} 元 (仅供参考)", expired: "订阅剩余天数", time: { @@ -604,11 +606,12 @@ const resources = { "free-conversation": "conversation storage", "free-sharing": "conversation sharing", "free-api": "API calls", - "plan-midjourney": "Midjourney {{times}} image generation per day", + "plan-midjourney": "Midjourney {{times}} image generation per month", "plan-midjourney-desc": "Midjourney Quick Image Generation", - "plan-gpt4": "GPT-4 {{times}} requests per day", + "plan-gpt4": "GPT-4 {{times}} requests per month", "plan-gpt4-desc": "including GPT 4 Turbo, GPT 4V, GPT 4 DALLE", - "plan-claude": "Claude 100k {{times}} requests per day", + "plan-claude": "Claude 100k {{times}} requests per month", + "plan-claude-desc": "including Claude 2 (100k), Claude 2.1 (200k)", "pro-service": "Priority Service Support", "pro-thread": "Concurrency Increase", enterprise: "Enterprise", @@ -630,6 +633,7 @@ const resources = { "migrate-plan-desc": "After changing the subscription, your subscription time will be calculated based on the remaining days price, and the subscription time will be recalculated. (For example, downgrading will double the time, and upgrading will make up the difference)", price: "Price {{price}} CNY", + "price-tax": "Include Tax {{price}} CNY", "upgrade-price": "Upgrade Fee {{price}} CNY (for reference only)", expired: "Subscription Remaining Days", time: { @@ -1068,11 +1072,12 @@ const resources = { "free-conversation": "хранение разговоров", "free-sharing": "общий доступ к разговорам", "free-api": "API вызовы", - "plan-midjourney": "Midjourney {{times}} генерация изображений в день", + "plan-midjourney": "Midjourney {{times}} генерация изображений в месяц", "plan-midjourney-desc": "Быстрая генерация изображений Midjourney", - "plan-gpt4": "GPT-4 {{times}} запросов в день", + "plan-gpt4": "GPT-4 {{times}} запросов в месяц", "plan-gpt4-desc": "включая GPT 4 Turbo, GPT 4V, GPT 4 DALLE", - "plan-claude": "Claude 100k {{times}} запросов в день", + "plan-claude": "Claude 100k {{times}} запросов в месяц", + "plan-claude-desc": "включая Claude 2 (100k), Claude 2.1 (200k)", "pro-service": "Приоритетная служба поддержки", "pro-thread": "Увеличение параллелизма", enterprise: "Корпоративный", @@ -1094,6 +1099,7 @@ const resources = { "migrate-plan-desc": "После изменения подписки ваше время подписки будет рассчитываться на основе цены оставшихся дней, и время подписки будет пересчитано. (Например, понижение удваивает время, а повышение компенсирует разницу)", price: "Цена {{price}} CNY", + "price-tax": "Включая налог {{price}} CNY", "upgrade-price": "Плата за обновление {{price}} CNY (для справки)", expired: "Осталось дней подписки", time: { diff --git a/auth/plan.go b/auth/plan.go index b271f95..60f1d75 100644 --- a/auth/plan.go +++ b/auth/plan.go @@ -4,7 +4,11 @@ import ( "chat/globals" "chat/utils" "database/sql" + "errors" + "fmt" "github.com/go-redis/redis/v8" + "strings" + "time" ) type Plan struct { @@ -33,44 +37,121 @@ var Plans = []Plan{ }, { Level: 1, - Price: 18, + Price: 42, Usage: []PlanUsage{ - {Id: "gpt-4", Value: 25, Including: globals.IsGPT4NativeModel}, - {Id: "claude-100k", Value: 50, Including: globals.IsClaude100KModel}, - }, - }, - { - Level: 2, - Price: 36, - Usage: []PlanUsage{ - {Id: "gpt-4", Value: 50, Including: globals.IsGPT4NativeModel}, - {Id: "claude-100k", Value: 100, Including: globals.IsClaude100KModel}, - {Id: "midjourney", Value: 25, Including: globals.IsMidjourneyFastModel}, - }, - }, - { - Level: 3, - Price: 72, - Usage: []PlanUsage{ - {Id: "gpt-4", Value: 100, Including: globals.IsGPT4NativeModel}, - {Id: "claude-100k", Value: 200, Including: globals.IsClaude100KModel}, + {Id: "gpt-4", Value: 150, Including: globals.IsGPT4NativeModel}, + {Id: "claude-100k", Value: 300, Including: globals.IsClaude100KModel}, {Id: "midjourney", Value: 50, Including: globals.IsMidjourneyFastModel}, }, }, { - // enterprise - Level: 4, - Price: 999, - Usage: []PlanUsage{}, + Level: 2, + Price: 76, + Usage: []PlanUsage{ + {Id: "gpt-4", Value: 300, Including: globals.IsGPT4NativeModel}, + {Id: "claude-100k", Value: 600, Including: globals.IsClaude100KModel}, + {Id: "midjourney", Value: 100, Including: globals.IsMidjourneyFastModel}, + }, + }, + { + Level: 3, + Price: 148, + Usage: []PlanUsage{ + {Id: "gpt-4", Value: 100, Including: globals.IsGPT4NativeModel}, + {Id: "claude-100k", Value: 1200, Including: globals.IsClaude100KModel}, + {Id: "midjourney", Value: 200, Including: globals.IsMidjourneyFastModel}, + }, }, } +var planExp int64 = 0 + +func getOffsetFormat(offset time.Time, usage int64) string { + return fmt.Sprintf("%s/%d", offset.Format("2006-01-02:15:04:05"), usage) +} + +func GetSubscriptionUsage(cache *redis.Client, user *User, t string) (usage int64, offset time.Time) { + // example cache value: 2021-09-01:19:00:00/100 + // if date is longer than 1 month, reset usage + + offset = time.Now() + + key := globals.GetSubscriptionLimitFormat(t, user.ID) + v, err := utils.GetCache(cache, key) + if (err != nil && errors.Is(err, redis.Nil)) || len(v) == 0 { + usage = 0 + } + + seg := strings.Split(v, "/") + if len(seg) != 2 { + usage = 0 + } else { + date, err := time.Parse("2006-01-02:15:04:05", seg[0]) + usage = utils.ParseInt64(seg[1]) + if err != nil { + usage = 0 + } + + // check if date is longer than current date after 1 month, if true, reset usage + + if date.AddDate(0, 1, 0).Before(time.Now()) { + // date is longer than 1 month, reset usage + usage = 0 + + // get current date offset (1 month step) + // example: 2021-09-01:19:00:0/100 -> 2021-10-01:19:00:00/100 + + // copy date to offset + offset = date + + // example: + // current time: 2021-09-08:14:00:00 + // offset: 2021-07-01:19:00:00 + // expected offset: 2021-09-01:19:00:00 + // offset is not longer than current date, stop adding 1 month + + for offset.AddDate(0, 1, 0).Before(time.Now()) { + offset = offset.AddDate(0, 1, 0) + } + } else { + // date is not longer than 1 month, use current date value + + offset = date + } + } + + // set new cache value + _ = utils.SetCache(cache, key, getOffsetFormat(offset, usage), planExp) + + return +} + func IncreaseSubscriptionUsage(cache *redis.Client, user *User, t string, limit int64) bool { - return utils.IncrWithLimit(cache, globals.GetSubscriptionLimitFormat(t, user.ID), 1, limit, 60*60*24) // 1 day + key := globals.GetSubscriptionLimitFormat(t, user.ID) + usage, offset := GetSubscriptionUsage(cache, user, t) + + usage += 1 + if usage > limit { + return false + } + + // set new cache value + err := utils.SetCache(cache, key, getOffsetFormat(offset, usage), planExp) + return err == nil } func DecreaseSubscriptionUsage(cache *redis.Client, user *User, t string) bool { - return utils.DecrInt(cache, globals.GetSubscriptionLimitFormat(t, user.ID), 1) + key := globals.GetSubscriptionLimitFormat(t, user.ID) + usage, offset := GetSubscriptionUsage(cache, user, t) + + usage -= 1 + if usage < 0 { + return true + } + + // set new cache value + err := utils.SetCache(cache, key, getOffsetFormat(offset, usage), planExp) + return err == nil } func (p *Plan) GetUsage(user *User, db *sql.DB, cache *redis.Client) UsageMap { @@ -80,7 +161,25 @@ func (p *Plan) GetUsage(user *User, db *sql.DB, cache *redis.Client) UsageMap { } func (p *PlanUsage) GetUsage(user *User, db *sql.DB, cache *redis.Client) int64 { - return utils.MustInt(cache, globals.GetSubscriptionLimitFormat(p.Id, user.GetID(db))) + // preflight check + user.GetID(db) + usage, _ := GetSubscriptionUsage(cache, user, p.Id) + return usage +} + +func (p *PlanUsage) ResetUsage(user *User, cache *redis.Client) bool { + key := globals.GetSubscriptionLimitFormat(p.Id, user.ID) + _, offset := GetSubscriptionUsage(cache, user, p.Id) + + err := utils.SetCache(cache, key, getOffsetFormat(offset, 0), planExp) + return err == nil +} + +func (p *PlanUsage) CreateUsage(user *User, cache *redis.Client) bool { + key := globals.GetSubscriptionLimitFormat(p.Id, user.ID) + + err := utils.SetCache(cache, key, getOffsetFormat(time.Now(), 0), planExp) + return err == nil } func (p *PlanUsage) GetUsageForm(user *User, db *sql.DB, cache *redis.Client) Usage { diff --git a/auth/subscription.go b/auth/subscription.go index 52e63fc..6682711 100644 --- a/auth/subscription.go +++ b/auth/subscription.go @@ -139,7 +139,19 @@ func BuySubscription(db *sql.DB, cache *redis.Client, user *User, level int, mon // buy new subscription or renew subscription money := CountSubscriptionPrize(level, month) if user.Pay(cache, money) { + // migrate subscription user.AddSubscription(db, month, level) + + if before == 0 { + // new subscription + + plan := user.GetPlan(db) + for _, usage := range plan.Usage { + // create usage + usage.CreateUsage(user, cache) + } + } + return true } } else if before > level { diff --git a/globals/usage.go b/globals/usage.go index d7fd24b..d57ef08 100644 --- a/globals/usage.go +++ b/globals/usage.go @@ -2,9 +2,8 @@ package globals import ( "fmt" - "time" ) func GetSubscriptionLimitFormat(t string, id int64) string { - return fmt.Sprintf(":subscription-usage-%s:%s:%d", t, time.Now().Format("2006-01-02"), id) + return fmt.Sprintf("usage-%s:%d", t, id) } diff --git a/main.go b/main.go index 7b30163..693cf3a 100644 --- a/main.go +++ b/main.go @@ -26,7 +26,9 @@ func main() { channel.InitManager() app := gin.New() - middleware.RegisterMiddleware(app) + + worker := middleware.RegisterMiddleware(app) + defer worker() { auth.Register(app) diff --git a/manager/transhipment.go b/manager/transhipment.go index 7fc7771..ea8b8fb 100644 --- a/manager/transhipment.go +++ b/manager/transhipment.go @@ -124,7 +124,6 @@ func TranshipmentAPI(c *gin.Context) { } db := utils.GetDBFromContext(c) - cache := utils.GetCacheFromContext(c) user := &auth.User{ Username: username, } @@ -143,16 +142,16 @@ func TranshipmentAPI(c *gin.Context) { form.Official = true } - check, plan := auth.CanEnableModelWithSubscription(db, cache, user, form.Model) + check := auth.CanEnableModel(db, user, form.Model) if !check { sendErrorResponse(c, fmt.Errorf("quota exceeded"), "quota_exceeded_error") return } if form.Stream { - sendStreamTranshipmentResponse(c, form, id, created, user, plan) + sendStreamTranshipmentResponse(c, form, id, created, user, false) } else { - sendTranshipmentResponse(c, form, id, created, user, plan) + sendTranshipmentResponse(c, form, id, created, user, false) } } diff --git a/middleware/middleware.go b/middleware/middleware.go index 30d9734..a9d2721 100644 --- a/middleware/middleware.go +++ b/middleware/middleware.go @@ -5,9 +5,17 @@ import ( "github.com/gin-gonic/gin" ) -func RegisterMiddleware(app *gin.Engine) { +func RegisterMiddleware(app *gin.Engine) func() { + db := connection.InitMySQLSafe() + cache := connection.InitRedisSafe() + app.Use(CORSMiddleware()) - app.Use(BuiltinMiddleWare(connection.InitMySQLSafe(), connection.InitRedisSafe())) + app.Use(BuiltinMiddleWare(db, cache)) app.Use(ThrottleMiddleware()) app.Use(AuthMiddleware()) + + return func() { + db.Close() + cache.Close() + } } diff --git a/utils/cache.go b/utils/cache.go index d5139e8..fdf4134 100644 --- a/utils/cache.go +++ b/utils/cache.go @@ -45,6 +45,14 @@ func GetJson[T any](cache *redis.Client, key string) *T { return UnmarshalForm[T](val) } +func GetCache(cache *redis.Client, key string) (string, error) { + return cache.Get(context.Background(), key).Result() +} + +func SetCache(cache *redis.Client, key string, value string, expiration int64) error { + return cache.Set(context.Background(), key, value, time.Duration(expiration)*time.Second).Err() +} + func IncrWithLimit(cache *redis.Client, key string, delta int64, limit int64, expiration int64) bool { // not exist if _, err := cache.Get(context.Background(), key).Result(); err != nil {