diff --git a/app/src/api/types.ts b/app/src/api/types.ts index 91851e6..4ee473f 100644 --- a/app/src/api/types.ts +++ b/app/src/api/types.ts @@ -1,5 +1,4 @@ import { Conversation } from "./conversation.ts"; -import React from "react"; export type Message = { role: string; @@ -49,11 +48,3 @@ export type Plan = { }; export type Plans = Plan[]; - -export type SubscriptionUsage = Record< - string, - { - icon: React.ReactElement; - name: string; - } ->; diff --git a/app/src/api/v1.ts b/app/src/api/v1.ts index 5022da9..4a71f74 100644 --- a/app/src/api/v1.ts +++ b/app/src/api/v1.ts @@ -1,5 +1,5 @@ import axios from "axios"; -import { Model } from "@/api/types.ts"; +import { Model, Plan } from "@/api/types.ts"; import { ChargeProps } from "@/admin/charge.ts"; export async function getApiModels(): Promise { @@ -12,6 +12,16 @@ export async function getApiModels(): Promise { } } +export async function getApiPlans(): Promise { + try { + const res = await axios.get("/v1/plans"); + return res.data as Plan[]; + } catch (e) { + console.warn(e); + return []; + } +} + export async function getApiMarket(): Promise { try { const res = await axios.get("/v1/market"); diff --git a/app/src/components/app/AppProvider.tsx b/app/src/components/app/AppProvider.tsx index 7f662e1..a01139a 100644 --- a/app/src/components/app/AppProvider.tsx +++ b/app/src/components/app/AppProvider.tsx @@ -3,14 +3,19 @@ import { ThemeProvider } from "@/components/ThemeProvider.tsx"; import DialogManager from "@/dialogs"; import Broadcast from "@/components/Broadcast.tsx"; import { useEffectAsync } from "@/utils/hook.ts"; -import { allModels, supportModels } from "@/conf"; +import { allModels, subscriptionData, supportModels } from "@/conf"; import { channelModels } from "@/admin/channel.ts"; -import { getApiCharge, getApiMarket, getApiModels } from "@/api/v1.ts"; +import { + getApiCharge, + getApiMarket, + getApiModels, + getApiPlans, +} from "@/api/v1.ts"; import { loadPreferenceModels } from "@/utils/storage.ts"; import { resetJsArray } from "@/utils/base.ts"; import { useDispatch } from "react-redux"; import { initChatModels } from "@/store/chat.ts"; -import { Model } from "@/api/types.ts"; +import { Model, Plan } from "@/api/types.ts"; import { ChargeProps, nonBilling } from "@/admin/charge.ts"; function AppProvider() { @@ -40,6 +45,12 @@ function AppProvider() { if (!allModels.includes(model)) allModels.push(model); if (!channelModels.includes(model)) channelModels.push(model); }); + + const plans = await getApiPlans(); + resetJsArray( + subscriptionData, + plans.filter((plan: Plan) => plan.level !== 0), + ); }, [allModels]); return ( diff --git a/app/src/components/app/MenuBar.tsx b/app/src/components/app/MenuBar.tsx index a53596d..aa784c3 100644 --- a/app/src/components/app/MenuBar.tsx +++ b/app/src/components/app/MenuBar.tsx @@ -29,6 +29,7 @@ import { openDialog as openApiDialog } from "@/store/api.ts"; import router from "@/router.tsx"; import { useDeeptrain } from "@/conf/env.ts"; import React from "react"; +import { subscriptionData } from "@/conf"; type MenuBarProps = { children: React.ReactNode; @@ -61,10 +62,12 @@ function MenuBar({ children, className }: MenuBarProps) { )} - dispatch(openSub())}> - - {t("sub.title")} - + {subscriptionData.length > 0 && ( + dispatch(openSub())}> + + {t("sub.title")} + + )} {useDeeptrain && ( dispatch(openPackageDialog())}> diff --git a/app/src/components/home/ModelFinder.tsx b/app/src/components/home/ModelFinder.tsx index 5f3b79f..e1640bb 100644 --- a/app/src/components/home/ModelFinder.tsx +++ b/app/src/components/home/ModelFinder.tsx @@ -1,5 +1,5 @@ import SelectGroup, { SelectItemProps } from "@/components/SelectGroup.tsx"; -import { supportModels } from "@/conf"; +import { subscriptionData, supportModels } from "@/conf"; import { openMarket, selectModel, @@ -84,7 +84,7 @@ function ModelFinder(props: ModelSelectorProps) { value: t("market.model"), }, ]; - }, [supportModels, level, student, sync]); + }, [supportModels, subscriptionData, level, student, sync]); const current = useMemo((): SelectItemProps => { const raw = models.find((item) => item.name === model); diff --git a/app/src/components/home/ModelMarket.tsx b/app/src/components/home/ModelMarket.tsx index 3a4517d..b080a1a 100644 --- a/app/src/components/home/ModelMarket.tsx +++ b/app/src/components/home/ModelMarket.tsx @@ -11,7 +11,7 @@ import { X, } from "lucide-react"; import React, { useMemo, useState } from "react"; -import { supportModels } from "@/conf"; +import { subscriptionData, supportModels } from "@/conf"; import { isUrl, splitList } from "@/utils/base.ts"; import { Model } from "@/api/types.ts"; import { useDispatch, useSelector } from "react-redux"; @@ -123,7 +123,7 @@ function ModelItem({ const pro = useMemo(() => { return includingModelFromPlan(level, model.id); - }, [model, level, student]); + }, [subscriptionData, model, level, student]); const avatar = useMemo(() => { return isUrl(model.avatar) ? model.avatar : `/icons/${model.avatar}`; diff --git a/app/src/conf/index.ts b/app/src/conf/index.ts index 047b25a..c7c156a 100644 --- a/app/src/conf/index.ts +++ b/app/src/conf/index.ts @@ -12,110 +12,15 @@ import { setAxiosConfig } from "@/conf/api.ts"; export const version = "3.8.6"; // version of the current build export const dev: boolean = getDev(); // is in development mode (for debugging, in localhost origin) export const deploy: boolean = true; // is production environment (for api endpoint) +export const tokenField = getTokenField(deploy); // token field name for storing token export let apiEndpoint: string = getRestApi(deploy); // api endpoint for rest api calls export let websocketEndpoint: string = getWebsocketApi(deploy); // api endpoint for websocket calls -export const tokenField = getTokenField(deploy); // token field name for storing token export let supportModels: Model[] = loadPreferenceModels(getOfflineModels()); // support models in model market of the current site export let allModels: string[] = supportModels.map((model) => model.id); // all support model id list of the current site -const GPT4Array = [ - "gpt-4", - "gpt-4-0314", - "gpt-4-0613", - "gpt-4-1106-preview", - "gpt-4-vision-preview", - "gpt-4-v", - "gpt-4-dalle", - "gpt-4-all", -]; -const Claude100kArray = ["claude-1.3", "claude-2", "claude-2.1"]; -const MidjourneyArray = ["midjourney-fast"]; - -export const subscriptionData: Plans = [ - { - level: 1, - price: 42, - items: [ - { - id: "gpt-4", - icon: "compass", - name: "GPT-4", - value: 150, - models: GPT4Array, - }, - { - id: "midjourney", - icon: "image-plus", - name: "Midjourney", - value: 50, - models: MidjourneyArray, - }, - { - id: "claude-100k", - icon: "book-text", - name: "Claude 100k", - value: 300, - models: Claude100kArray, - }, - ], - }, - { - level: 2, - price: 76, - items: [ - { - id: "gpt-4", - icon: "compass", - name: "GPT-4", - value: 300, - models: GPT4Array, - }, - { - id: "midjourney", - icon: "image-plus", - name: "Midjourney", - value: 100, - models: MidjourneyArray, - }, - { - id: "claude-100k", - icon: "book-text", - name: "Claude 100k", - value: 600, - models: Claude100kArray, - }, - ], - }, - { - level: 3, - price: 148, - items: [ - { - id: "gpt-4", - icon: "compass", - name: "GPT-4", - value: 600, - models: GPT4Array, - }, - { - id: "midjourney", - icon: "image-plus", - name: "Midjourney", - value: 200, - models: MidjourneyArray, - }, - { - id: "claude-100k", - icon: "book-text", - name: "Claude 100k", - value: 1200, - models: Claude100kArray, - }, - ], - }, -]; +export let subscriptionData: Plans = []; // subscription data of the current site setAxiosConfig({ endpoint: apiEndpoint, diff --git a/app/src/conf/subscription.tsx b/app/src/conf/subscription.tsx index 7911174..a422c09 100644 --- a/app/src/conf/subscription.tsx +++ b/app/src/conf/subscription.tsx @@ -41,7 +41,11 @@ export function SubscriptionIcon({ type, className }: SubscriptionIconProps) { export function getPlan(level: number): Plan { const raw = subscriptionData.filter((item) => item.level === level); - return raw.length > 0 ? raw[0] : subscriptionData[0]; + return raw.length > 0 + ? raw[0] + : subscriptionData.length + ? subscriptionData[0] + : { level: 0, price: 0, items: [] }; } export function getPlanModels(level: number): string[] { diff --git a/app/src/dialogs/QuotaDialog.tsx b/app/src/dialogs/QuotaDialog.tsx index 79ced17..ef78635 100644 --- a/app/src/dialogs/QuotaDialog.tsx +++ b/app/src/dialogs/QuotaDialog.tsx @@ -42,6 +42,7 @@ import { ToastAction } from "@/components/ui/toast.tsx"; import { deeptrainEndpoint, docsEndpoint, useDeeptrain } from "@/conf/env.ts"; import { useRedeem } from "@/api/redeem.ts"; import { cn } from "@/components/ui/lib/utils.ts"; +import { subscriptionData } from "@/conf"; type AmountComponentProps = { amount: number; @@ -105,14 +106,16 @@ function QuotaDialog() { {t("buy.choose")}
-

- sub ? dispatch(closeDialog()) : dispatch(openSubDialog()) - } - > - {t("sub.subscription-link")} -

+ {subscriptionData.length > 0 && ( +

+ sub ? dispatch(closeDialog()) : dispatch(openSubDialog()) + } + > + {t("sub.subscription-link")} +

+ )}
diff --git a/app/src/dialogs/SubscriptionDialog.tsx b/app/src/dialogs/SubscriptionDialog.tsx index 33d25ae..62c7741 100644 --- a/app/src/dialogs/SubscriptionDialog.tsx +++ b/app/src/dialogs/SubscriptionDialog.tsx @@ -46,8 +46,8 @@ type PlanItemProps = { function PlanItem({ level }: PlanItemProps) { const { t } = useTranslation(); const current = useSelector(levelSelector); - const plan = useMemo(() => getPlan(level), [level]); - const name = useMemo(() => getPlanName(level), [level]); + const plan = useMemo(() => getPlan(level), [subscriptionData, level]); + const name = useMemo(() => getPlanName(level), [subscriptionData, level]); return (
diff --git a/app/src/store/subscription.ts b/app/src/store/subscription.ts index bcab9e0..7cbee89 100644 --- a/app/src/store/subscription.ts +++ b/app/src/store/subscription.ts @@ -1,6 +1,7 @@ import { createSlice } from "@reduxjs/toolkit"; import { getSubscription } from "@/api/addition.ts"; import { AppDispatch } from "./index.ts"; +import { subscriptionData } from "@/conf"; export const subscriptionSlice = createSlice({ name: "subscription", @@ -16,12 +17,14 @@ export const subscriptionSlice = createSlice({ }, reducers: { toggleDialog: (state) => { + if (!state.dialog && !subscriptionData.length) return; state.dialog = !state.dialog; }, setDialog: (state, action) => { state.dialog = action.payload as boolean; }, openDialog: (state) => { + if (!subscriptionData.length) return; state.dialog = true; }, closeDialog: (state) => { diff --git a/auth/plan.go b/auth/plan.go index 9837d15..e23b8de 100644 --- a/auth/plan.go +++ b/auth/plan.go @@ -12,54 +12,69 @@ import ( ) type Plan struct { - Level int - Price float32 - Usage []PlanUsage + Level int `json:"level" mapstructure:"level"` + Price float32 `json:"price" mapstructure:"price"` + Items []PlanItem `json:"items" mapstructure:"items"` } -type PlanUsage struct { - Id string - Value int64 - Including func(string) bool +type PlanItem struct { + Id string `json:"id" mapstructure:"id"` + Name string `json:"name" mapstructure:"name"` + Icon string `json:"icon" mapstructure:"icon"` + Value int64 `json:"value" mapstructure:"value"` + Models []string `json:"models" mapstructure:"models"` } type Usage struct { - Used int64 `json:"used"` - Total int64 `json:"total"` + Used int64 `json:"used" mapstructure:"used"` + Total int64 `json:"total" mapstructure:"total"` } type UsageMap map[string]Usage +var GPT4Array = []string{ + globals.GPT4, globals.GPT40314, globals.GPT40613, globals.GPT41106Preview, globals.GPT41106VisionPreview, + globals.GPT4Vision, globals.GPT4Dalle, globals.GPT4All, +} + +var ClaudeProArray = []string{ + globals.Claude1100k, globals.Claude2100k, globals.Claude2200k, +} + +var MidjourneyArray = []string{ + globals.MidjourneyFast, +} + var Plans = []Plan{ { Level: 0, Price: 0, - Usage: []PlanUsage{}, + Items: []PlanItem{}, }, { Level: 1, Price: 42, - Usage: []PlanUsage{ - {Id: "gpt-4", Value: 150, Including: globals.IsGPT4NativeModel}, - {Id: "claude-100k", Value: 300, Including: globals.IsClaude100KModel}, - {Id: "midjourney", Value: 50, Including: globals.IsMidjourneyFastModel}, + Items: []PlanItem{ + {Id: "gpt-4", Value: 150, Models: GPT4Array, Name: "GPT-4", Icon: "compass"}, + {Id: "midjourney", Value: 50, Models: MidjourneyArray, Name: "Midjourney", Icon: "image-plus"}, + {Id: "claude-100k", Value: 300, Models: ClaudeProArray, Name: "Claude 100k", Icon: "book-text"}, }, }, { Level: 2, Price: 76, - Usage: []PlanUsage{ - {Id: "gpt-4", Value: 300, Including: globals.IsGPT4NativeModel}, - {Id: "claude-100k", Value: 600, Including: globals.IsClaude100KModel}, - {Id: "midjourney", Value: 100, Including: globals.IsMidjourneyFastModel}, + Items: []PlanItem{ + {Id: "gpt-4", Value: 300, Models: GPT4Array, Name: "GPT-4", Icon: "compass"}, + {Id: "midjourney", Value: 100, Models: MidjourneyArray, Name: "Midjourney", Icon: "image-plus"}, + {Id: "claude-100k", Value: 600, Models: ClaudeProArray, Name: "Claude 100k", Icon: "book-text"}, }, }, { Level: 3, Price: 148, - Usage: []PlanUsage{ - {Id: "gpt-4", Value: 600, Including: globals.IsGPT4NativeModel}, - {Id: "claude-100k", Value: 1200, Including: globals.IsClaude100KModel}, - {Id: "midjourney", Value: 200, Including: globals.IsMidjourneyFastModel}, + Items: []PlanItem{ + {Id: "gpt-4", Value: 600, Models: GPT4Array, Name: "GPT-4", Icon: "compass"}, + {Id: "midjourney", Value: 200, Models: MidjourneyArray, Name: "Midjourney", Icon: "image-plus"}, + {Id: "claude-100k", Value: 1200, Models: ClaudeProArray, Name: "Claude 100k", Icon: "book-text"}, }, }, } @@ -155,19 +170,19 @@ func DecreaseSubscriptionUsage(cache *redis.Client, user *User, t string) bool { } func (p *Plan) GetUsage(user *User, db *sql.DB, cache *redis.Client) UsageMap { - return utils.EachObject[PlanUsage, Usage](p.Usage, func(usage PlanUsage) (string, Usage) { + return utils.EachObject[PlanItem, Usage](p.Items, func(usage PlanItem) (string, Usage) { return usage.Id, usage.GetUsageForm(user, db, cache) }) } -func (p *PlanUsage) GetUsage(user *User, db *sql.DB, cache *redis.Client) int64 { +func (p *PlanItem) GetUsage(user *User, db *sql.DB, cache *redis.Client) int64 { // preflight check user.GetID(db) usage, _ := GetSubscriptionUsage(cache, user, p.Id) return usage } -func (p *PlanUsage) ResetUsage(user *User, cache *redis.Client) bool { +func (p *PlanItem) ResetUsage(user *User, cache *redis.Client) bool { key := globals.GetSubscriptionLimitFormat(p.Id, user.ID) _, offset := GetSubscriptionUsage(cache, user, p.Id) @@ -175,36 +190,36 @@ func (p *PlanUsage) ResetUsage(user *User, cache *redis.Client) bool { return err == nil } -func (p *PlanUsage) CreateUsage(user *User, cache *redis.Client) bool { +func (p *PlanItem) CreateUsage(user *User, cache *redis.Client) bool { key := globals.GetSubscriptionLimitFormat(p.Id, user.ID) err := utils.SetCache(cache, key, getOffsetFormat(time.Now(), 0), planExp) return err == nil } -func (p *PlanUsage) GetUsageForm(user *User, db *sql.DB, cache *redis.Client) Usage { +func (p *PlanItem) GetUsageForm(user *User, db *sql.DB, cache *redis.Client) Usage { return Usage{ Used: p.GetUsage(user, db, cache), Total: p.Value, } } -func (p *PlanUsage) IsInfinity() bool { +func (p *PlanItem) IsInfinity() bool { return p.Value == -1 } -func (p *PlanUsage) IsExceeded(user *User, db *sql.DB, cache *redis.Client) bool { +func (p *PlanItem) IsExceeded(user *User, db *sql.DB, cache *redis.Client) bool { return p.IsInfinity() || p.GetUsage(user, db, cache) < p.Value } -func (p *PlanUsage) Increase(user *User, cache *redis.Client) bool { +func (p *PlanItem) Increase(user *User, cache *redis.Client) bool { if p.Value == -1 { return true } return IncreaseSubscriptionUsage(cache, user, p.Id, p.Value) } -func (p *PlanUsage) Decrease(user *User, cache *redis.Client) bool { +func (p *PlanItem) Decrease(user *User, cache *redis.Client) bool { if p.Value == -1 { return true } @@ -217,8 +232,8 @@ func (u *User) GetSubscriptionUsage(db *sql.DB, cache *redis.Client) UsageMap { } func (p *Plan) IncreaseUsage(user *User, cache *redis.Client, model string) bool { - for _, usage := range p.Usage { - if usage.Including(model) { + for _, usage := range p.Items { + if utils.Contains(model, usage.Models) { return usage.Increase(user, cache) } } @@ -227,8 +242,8 @@ func (p *Plan) IncreaseUsage(user *User, cache *redis.Client, model string) bool } func (p *Plan) DecreaseUsage(user *User, cache *redis.Client, model string) bool { - for _, usage := range p.Usage { - if usage.Including(model) { + for _, usage := range p.Items { + if utils.Contains(model, usage.Models) { return usage.Decrease(user, cache) } } diff --git a/auth/subscription.go b/auth/subscription.go index 1fc6254..9254151 100644 --- a/auth/subscription.go +++ b/auth/subscription.go @@ -146,7 +146,7 @@ func BuySubscription(db *sql.DB, cache *redis.Client, user *User, level int, mon // new subscription plan := user.GetPlan(db) - for _, usage := range plan.Usage { + for _, usage := range plan.Items { // create usage usage.CreateUsage(user, cache) } diff --git a/globals/variables.go b/globals/variables.go index 7e51168..c82af98 100644 --- a/globals/variables.go +++ b/globals/variables.go @@ -97,9 +97,8 @@ const ( SkylarkChat = "skylark-chat" ) -var GPT4Array = []string{ - GPT4, GPT40314, GPT40613, GPT41106Preview, GPT41106VisionPreview, - GPT4Vision, GPT4Dalle, GPT4All, +var DalleModels = []string{ + Dalle, Dalle2, Dalle3, } func in(value string, slice []string) bool { @@ -111,22 +110,13 @@ func in(value string, slice []string) bool { return false } -func IsGPT4NativeModel(model string) bool { - return in(model, GPT4Array) -} - func IsDalleModel(model string) bool { - return model == Dalle || model == Dalle2 || model == Dalle3 -} - -func IsClaude100KModel(model string) bool { - return model == Claude1100k || model == Claude2100k || model == Claude2200k -} - -func IsMidjourneyFastModel(model string) bool { - return model == MidjourneyFast + // using image generation api if model is in dalle models + return in(model, DalleModels) } func IsGPT41106VisionPreview(model string) bool { - return model == GPT41106VisionPreview || strings.Contains(model, GPT41106VisionPreview) + // enable openai image format for gpt-4-vision-preview model + return model == GPT41106VisionPreview || + strings.Contains(model, GPT41106VisionPreview) } diff --git a/manager/relay.go b/manager/relay.go index 31f369d..9e8b71e 100644 --- a/manager/relay.go +++ b/manager/relay.go @@ -2,6 +2,7 @@ package manager import ( "chat/admin" + "chat/auth" "chat/channel" "github.com/gin-gonic/gin" "net/http" @@ -19,6 +20,10 @@ func ChargeAPI(c *gin.Context) { c.JSON(http.StatusOK, channel.ChargeInstance.ListRules()) } +func PlanAPI(c *gin.Context) { + c.JSON(http.StatusOK, auth.Plans) +} + func sendErrorResponse(c *gin.Context, err error, types ...string) { var errType string if len(types) > 0 { diff --git a/manager/router.go b/manager/router.go index 7ca7404..3bb62cd 100644 --- a/manager/router.go +++ b/manager/router.go @@ -10,6 +10,7 @@ func Register(app *gin.RouterGroup) { app.GET("/v1/models", ModelAPI) app.GET("/v1/market", MarketAPI) app.GET("/v1/charge", ChargeAPI) + app.GET("/v1/plans", PlanAPI) app.GET("/dashboard/billing/usage", GetBillingUsage) app.GET("/dashboard/billing/subscription", GetSubscription) app.POST("/v1/chat/completions", ChatRelayAPI)