From 5affcbb4db1d8e89679901bfe6b13dc98a88c263 Mon Sep 17 00:00:00 2001 From: Zhang Minghan Date: Wed, 20 Dec 2023 19:33:01 +0800 Subject: [PATCH] feat: channel group --- adapter/adapter.go | 1 + adapter/request.go | 3 +- addition/generation/api.go | 22 +++-- addition/generation/generate.go | 4 +- addition/generation/prompt.go | 4 +- app/src/admin/channel.ts | 9 ++ app/src/assets/admin/channel.less | 8 ++ app/src/assets/ui.less | 4 + .../admin/assemblies/ChannelEditor.tsx | 87 +++++++++++++++++++ app/src/i18n.ts | 9 ++ app/src/translator/index.ts | 12 +++ app/tsconfig.node.json | 2 +- app/vite.config.ts | 2 + auth/struct.go | 21 +++++ channel/channel.go | 12 +++ channel/manager.go | 8 +- channel/ticker.go | 23 ++++- channel/types.go | 1 + channel/worker.go | 12 ++- manager/chat.go | 36 ++++---- manager/completions.go | 22 +++-- manager/transhipment.go | 4 +- 22 files changed, 251 insertions(+), 55 deletions(-) create mode 100644 app/src/translator/index.ts diff --git a/adapter/adapter.go b/adapter/adapter.go index 35a502e..2967c3d 100644 --- a/adapter/adapter.go +++ b/adapter/adapter.go @@ -23,6 +23,7 @@ import ( type RequestProps struct { MaxRetries *int Current int + Group string } type ChatProps struct { diff --git a/adapter/request.go b/adapter/request.go index c6702d9..4f37e91 100644 --- a/adapter/request.go +++ b/adapter/request.go @@ -36,7 +36,8 @@ func NewChatRequest(conf globals.ChannelConfig, props *ChatProps, hook globals.H } if props.Current < retries { - globals.Warn(fmt.Sprintf("retrying chat request for %s (attempt %d/%d, error: %s)", props.Model, props.Current+1, retries, err.Error())) + content := strings.Replace(err.Error(), "\n", "", -1) + globals.Warn(fmt.Sprintf("retrying chat request for %s (attempt %d/%d, error: %s)", props.Model, props.Current+1, retries, content)) return NewChatRequest(conf, props, hook) } } diff --git a/addition/generation/api.go b/addition/generation/api.go index 6fb19f9..fb44657 100644 --- a/addition/generation/api.go +++ b/addition/generation/api.go @@ -72,14 +72,20 @@ func GenerateAPI(c *gin.Context) { } var instance *utils.Buffer - hash, err := CreateGenerationWithCache(form.Model, form.Prompt, plan, func(buffer *utils.Buffer, data string) { - instance = buffer - conn.Send(globals.GenerationSegmentResponse{ - End: false, - Message: data, - Quota: buffer.GetQuota(), - }) - }) + hash, err := CreateGenerationWithCache( + auth.GetGroup(db, user), + form.Model, + form.Prompt, + plan, + func(buffer *utils.Buffer, data string) { + instance = buffer + conn.Send(globals.GenerationSegmentResponse{ + End: false, + Message: data, + Quota: buffer.GetQuota(), + }) + }, + ) if instance != nil && !plan && instance.GetQuota() > 0 && user != nil { user.UseQuota(db, instance.GetQuota()) diff --git a/addition/generation/generate.go b/addition/generation/generate.go index b58128e..3e931a3 100644 --- a/addition/generation/generate.go +++ b/addition/generation/generate.go @@ -6,10 +6,10 @@ import ( "fmt" ) -func CreateGenerationWithCache(model string, prompt string, enableReverse bool, hook func(buffer *utils.Buffer, data string)) (string, error) { +func CreateGenerationWithCache(group, model, prompt string, enableReverse bool, hook func(buffer *utils.Buffer, data string)) (string, error) { hash, path := GetFolderByHash(model, prompt) if !utils.Exists(path) { - if err := CreateGeneration(model, prompt, path, enableReverse, hook); err != nil { + if err := CreateGeneration(group, model, prompt, path, enableReverse, hook); err != nil { globals.Info(fmt.Sprintf("[project] error during generation %s (model %s): %s", prompt, model, err.Error())) return "", fmt.Errorf("error during generate project: %s", err.Error()) } diff --git a/addition/generation/prompt.go b/addition/generation/prompt.go index 5c698c2..6c71c4c 100644 --- a/addition/generation/prompt.go +++ b/addition/generation/prompt.go @@ -13,11 +13,11 @@ type ProjectResult struct { Result map[string]interface{} `json:"result"` } -func CreateGeneration(model string, prompt string, path string, plan bool, hook func(buffer *utils.Buffer, data string)) error { +func CreateGeneration(group, model, prompt, path string, plan bool, hook func(buffer *utils.Buffer, data string)) error { message := GenerateMessage(prompt) buffer := utils.NewBuffer(model, message, channel.ChargeInstance.GetCharge(model)) - err := channel.NewChatRequest(&adapter.ChatProps{ + err := channel.NewChatRequest(group, &adapter.ChatProps{ Model: model, Message: message, Plan: plan, diff --git a/app/src/admin/channel.ts b/app/src/admin/channel.ts index 8b83101..6772457 100644 --- a/app/src/admin/channel.ts +++ b/app/src/admin/channel.ts @@ -10,6 +10,7 @@ export type Channel = { endpoint: string; mapper: string; state: boolean; + group?: string[]; }; export type ChannelInfo = { @@ -164,6 +165,14 @@ export const channelModels: string[] = Object.values(ChannelInfos).flatMap( (info) => info.models, ); +export const channelGroups: string[] = [ + "anonymous", + "normal", + "basic", + "standard", + "pro", +]; + export function getChannelInfo(type?: string): ChannelInfo { if (type && type in ChannelInfos) return ChannelInfos[type]; return ChannelInfos.openai; diff --git a/app/src/assets/admin/channel.less b/app/src/assets/admin/channel.less index 3de235d..8dbd641 100644 --- a/app/src/assets/admin/channel.less +++ b/app/src/assets/admin/channel.less @@ -75,6 +75,7 @@ border-radius: var(--radius); transition: .25s; height: max-content; + white-space: break-spaces; &:hover { border-color: hsl(var(--border-hover)); @@ -87,6 +88,7 @@ margin-left: 0.5rem; color: hsl(var(--text-secondary)); transition: .25s; + flex-shrink: 0; &:hover { color: hsl(var(--text-primary)); @@ -101,6 +103,12 @@ width: 100%; flex-wrap: wrap; gap: 0.5rem; + + @media (max-width: 620px) { + & > * { + width: 100%; + } + } } .channel-description { diff --git a/app/src/assets/ui.less b/app/src/assets/ui.less index b215619..6e46082 100644 --- a/app/src/assets/ui.less +++ b/app/src/assets/ui.less @@ -1,3 +1,7 @@ +.gold-text { + color: hsl(var(--gold)) !important; +} + .select-group { display: flex; flex-direction: row; diff --git a/app/src/components/admin/assemblies/ChannelEditor.tsx b/app/src/components/admin/assemblies/ChannelEditor.tsx index 0c63776..81d0e25 100644 --- a/app/src/components/admin/assemblies/ChannelEditor.tsx +++ b/app/src/components/admin/assemblies/ChannelEditor.tsx @@ -10,6 +10,7 @@ import { } from "@/components/ui/select.tsx"; import { Channel, + channelGroups, channelModels, ChannelTypes, getChannelInfo, @@ -54,6 +55,7 @@ const initialState: Channel = { endpoint: getChannelInfo().endpoint, mapper: "", state: true, + group: [], }; type CustomActionProps = { @@ -133,6 +135,18 @@ function reducer(state: Channel, action: any) { return { ...state, retry: action.value }; case "clear": return { ...initialState }; + case "add-group": + return { + ...state, + group: state.group ? [...state.group, action.value] : [action.value], + }; + case "remove-group": + return { + ...state, + group: state.group + ? state.group.filter((group) => group !== action.value) + : [], + }; case "set": return { ...state, ...action.value }; default: @@ -171,6 +185,9 @@ function handler(data: Channel): Channel { ); }) .join("\n"); + data.group = data.group + ? data.group.filter((group) => group.trim() !== "") + : []; return data; } @@ -191,6 +208,13 @@ function ChannelEditor({ display, id, setEnabled }: ChannelEditorProps) { }, [edit.models]); const enabled = useMemo(() => validator(edit), [edit]); + const unusedGroups = useMemo(() => { + if (!edit.group) return channelGroups; + return channelGroups.filter( + (group) => !edit.group.includes(group) && group !== "", + ); + }, [edit.group]); + function close(clear?: boolean) { if (clear) dispatch({ type: "clear" }); setEnabled(false); @@ -409,6 +433,69 @@ function ChannelEditor({ display, id, setEnabled }: ChannelEditorProps) { } /> +
+
+ {t("admin.channels.group")} + +
+
+
+ {(edit.group || []).map((item: string, idx: number) => ( +
+ {item} + + dispatch({ type: "remove-group", value: item }) + } + /> +
+ ))} +
+ + + + + + + + {unusedGroups.length === 0 ? ( +

+ {t("conversation.empty")} +

+ ) : ( + unusedGroups.map((item, idx) => ( + + dispatch({ type: "add-group", value: item }) + } + className={`px-2 ${idx > 1 ? "gold-text" : ""}`} + > + {item} + + )) + )} +
+
+
+
+
+
diff --git a/app/src/i18n.ts b/app/src/i18n.ts index e4021a9..a8000a9 100644 --- a/app/src/i18n.ts +++ b/app/src/i18n.ts @@ -397,6 +397,9 @@ const resources = { "请输入模型映射,一行一个,格式: model>model\n" + "前者为请求的模型,后者为映射的模型(需要在模型中存在),中间用 > 分隔\n" + "格式前加!表示原模型不包含在此渠道的可用范围内,如: !gpt-4-slow>gpt-4,那么 gpt-4 将不会被涵盖在此渠道的可请求模型中", + group: "用户分组", + "group-tip": + "用户分组,未包含的分组将不包含在此渠道的可用范围内 (分组为空时,所有用户都可以使用此渠道)", state: "状态", action: "操作", edit: "编辑渠道", @@ -862,6 +865,9 @@ const resources = { "Please enter the model mapper, one line each, format: model>model\n" + "The former is the requested model, and the latter is the mapped model (which needs to exist in the model), separated by > in the middle\n" + "The format is preceded by! Indicates that the original model is not included in the available range of this channel, such as: !gpt-4-slow>gpt-4, then gpt-4 will not be covered in the available models that can be requested in this channel", + group: "User Group", + "group-tip": + "User group, the group that is not included will not be included in the available range of this channel (when the group is empty, all users can use this channel)", state: "State", action: "Action", edit: "Edit Channel", @@ -1331,6 +1337,9 @@ const resources = { "Введите модельный маппер, по одной строке, формат: model>model\n" + "Первая модель - запрошенная модель, вторая модель - отображаемая модель (которая должна существовать в модели), разделенная > посередине\n" + "Формат предшествует! Означает, что исходная модель не включена в доступный диапазон этого канала, например: !gpt-4-slow>gpt-4, тогда gpt-4 не будет охвачен в доступных моделях, которые можно запросить в этом канале", + group: "Группа пользователей", + "group-tip": + "Группа пользователей, группа, которая не включена, не будет включена в доступный диапазон этого канала (когда группа пуста, все пользователи могут использовать этот канал)", state: "Статус", action: "Действие", edit: "Редактировать канал", diff --git a/app/src/translator/index.ts b/app/src/translator/index.ts new file mode 100644 index 0000000..5603464 --- /dev/null +++ b/app/src/translator/index.ts @@ -0,0 +1,12 @@ +import { Plugin } from "vite"; +import path from "path"; + +export function createTranslationPlugin(): Plugin { + return { + name: "translate-plugin", + apply: "build", + configResolved(config) { + const dir = path.resolve(config.root, "src"); + }, + }; +} diff --git a/app/tsconfig.node.json b/app/tsconfig.node.json index 42872c5..0bd00f5 100644 --- a/app/tsconfig.node.json +++ b/app/tsconfig.node.json @@ -6,5 +6,5 @@ "moduleResolution": "bundler", "allowSyntheticDefaultImports": true }, - "include": ["vite.config.ts"] + "include": ["vite.config.ts", "src/translator"] } diff --git a/app/vite.config.ts b/app/vite.config.ts index 5a7cdea..ed41aa4 100644 --- a/app/vite.config.ts +++ b/app/vite.config.ts @@ -2,6 +2,7 @@ import { defineConfig } from 'vite' import react from '@vitejs/plugin-react-swc' import path from "path" import { createHtmlPlugin } from 'vite-plugin-html' +import { createTranslationPlugin } from "./src/translator"; // https://vitejs.dev/config/ export default defineConfig({ @@ -10,6 +11,7 @@ export default defineConfig({ createHtmlPlugin({ minify: true, }), + createTranslationPlugin(), ], resolve: { alias: { diff --git a/auth/struct.go b/auth/struct.go index e13da4a..250ba46 100644 --- a/auth/struct.go +++ b/auth/struct.go @@ -1,6 +1,7 @@ package auth import ( + "chat/globals" "database/sql" "time" ) @@ -62,3 +63,23 @@ func IsUserExist(db *sql.DB, username string) bool { } return count > 0 } + +func GetGroup(db *sql.DB, user *User) string { + if user == nil { + return globals.AnonymousType + } + + level := user.GetSubscriptionLevel(db) + switch level { + case 0: + return globals.NormalType + case 1: + return globals.BasicType + case 2: + return globals.StandardType + case 3: + return globals.ProType + default: + return globals.NormalType + } +} diff --git a/channel/channel.go b/channel/channel.go index 29d44c9..de06e14 100644 --- a/channel/channel.go +++ b/channel/channel.go @@ -158,6 +158,18 @@ func (c *Channel) GetState() bool { return c.State } +func (c *Channel) GetGroup() []string { + return c.Group +} + +func (c *Channel) IsHitGroup(group string) bool { + if len(c.GetGroup()) == 0 { + return true + } + + return utils.Contains(group, c.GetGroup()) +} + func (c *Channel) IsHit(model string) bool { return utils.Contains(model, c.GetHitModels()) } diff --git a/channel/manager.go b/channel/manager.go index 68f4ebb..7b9ad68 100644 --- a/channel/manager.go +++ b/channel/manager.go @@ -99,10 +99,12 @@ func (m *Manager) HasChannel(model string) bool { return utils.Contains(model, m.Models) } -func (m *Manager) GetTicker(model string) *Ticker { - return &Ticker{ - Sequence: m.HitSequence(model), +func (m *Manager) GetTicker(model, group string) *Ticker { + if !m.HasChannel(model) { + return nil } + + return NewTicker(m.HitSequence(model), group) } func (m *Manager) Len() int { diff --git a/channel/ticker.go b/channel/ticker.go index 436ad68..4b56406 100644 --- a/channel/ticker.go +++ b/channel/ticker.go @@ -2,6 +2,21 @@ package channel import "chat/utils" +func NewTicker(seq Sequence, group string) *Ticker { + stack := make(Sequence, 0) + for _, channel := range seq { + if channel.IsHitGroup(group) { + stack = append(stack, channel) + } + } + + stack.Sort() + + return &Ticker{ + Sequence: stack, + } +} + func (t *Ticker) GetChannelByPriority(priority int) *Channel { var stack Sequence @@ -72,10 +87,10 @@ func (t *Ticker) SkipPriority(priority int) { } } -func (t *Ticker) Skip() { - t.Cursor++ -} - func (t *Ticker) IsDone() bool { return t.Cursor >= len(t.Sequence) } + +func (t *Ticker) IsEmpty() bool { + return len(t.Sequence) == 0 +} diff --git a/channel/types.go b/channel/types.go index a5be364..23a0e57 100644 --- a/channel/types.go +++ b/channel/types.go @@ -12,6 +12,7 @@ type Channel struct { Endpoint string `json:"endpoint" mapstructure:"endpoint"` Mapper string `json:"mapper" mapstructure:"mapper"` State bool `json:"state" mapstructure:"state"` + Group []string `json:"group" mapstructure:"group"` Reflect *map[string]string `json:"-"` HitModels *[]string `json:"-"` ExcludeModels *[]string `json:"-"` diff --git a/channel/worker.go b/channel/worker.go index 386109d..ab9070b 100644 --- a/channel/worker.go +++ b/channel/worker.go @@ -5,16 +5,14 @@ import ( "chat/globals" "chat/utils" "fmt" - "github.com/cloudwego/hertz/cmd/hz/util/logs" ) -func NewChatRequest(props *adapter.ChatProps, hook globals.Hook) error { - if !ConduitInstance.HasChannel(props.Model) { +func NewChatRequest(group string, props *adapter.ChatProps, hook globals.Hook) error { + ticker := ConduitInstance.GetTicker(props.Model, group) + if ticker == nil || ticker.IsEmpty() { return fmt.Errorf("cannot find channel for model %s", props.Model) } - ticker := ConduitInstance.GetTicker(props.Model) - var err error for !ticker.IsDone() { if channel := ticker.Next(); channel != nil { @@ -23,10 +21,10 @@ func NewChatRequest(props *adapter.ChatProps, hook globals.Hook) error { return nil } - logs.Warn(fmt.Sprintf("[channel] caught error %s for model %s at channel %s", err.Error(), props.Model, channel.GetName())) + globals.Warn(fmt.Sprintf("[channel] caught error %s for model %s at channel %s", err.Error(), props.Model, channel.GetName())) } } - logs.Info(fmt.Sprintf("[channel] channels are exhausted for model %s", props.Model)) + globals.Info(fmt.Sprintf("[channel] channels are exhausted for model %s", props.Model)) return err } diff --git a/manager/chat.go b/manager/chat.go index 1e8f5a8..c599e1a 100644 --- a/manager/chat.go +++ b/manager/chat.go @@ -88,23 +88,27 @@ func ChatHandler(conn *Connection, user *auth.User, instance *conversation.Conve } buffer := utils.NewBuffer(model, segment, channel.ChargeInstance.GetCharge(model)) - err := channel.NewChatRequest(&adapter.ChatProps{ - Model: model, - Message: segment, - Plan: plan, - Buffer: *buffer, - }, func(data string) error { - if signal := conn.PeekWithType(StopType); signal != nil { - // stop signal from client - return fmt.Errorf("signal") - } - return conn.SendClient(globals.ChatSegmentResponse{ - Message: buffer.Write(data), - Quota: buffer.GetQuota(), - End: false, + err := channel.NewChatRequest( + auth.GetGroup(db, user), + &adapter.ChatProps{ + Model: model, + Message: segment, Plan: plan, - }) - }) + Buffer: *buffer, + }, + func(data string) error { + if signal := conn.PeekWithType(StopType); signal != nil { + // stop signal from client + return fmt.Errorf("signal") + } + return conn.SendClient(globals.ChatSegmentResponse{ + Message: buffer.Write(data), + Quota: buffer.GetQuota(), + End: false, + Plan: plan, + }) + }, + ) admin.AnalysisRequest(model, buffer, err) if err != nil && err.Error() != "signal" { diff --git a/manager/completions.go b/manager/completions.go index 934e566..f932482 100644 --- a/manager/completions.go +++ b/manager/completions.go @@ -40,15 +40,19 @@ func NativeChatHandler(c *gin.Context, user *auth.User, model string, message [] } buffer := utils.NewBuffer(model, segment, channel.ChargeInstance.GetCharge(model)) - err := channel.NewChatRequest(&adapter.ChatProps{ - Model: model, - Plan: plan, - Message: segment, - Buffer: *buffer, - }, func(resp string) error { - buffer.Write(resp) - return nil - }) + err := channel.NewChatRequest( + auth.GetGroup(db, user), + &adapter.ChatProps{ + Model: model, + Plan: plan, + Message: segment, + Buffer: *buffer, + }, + func(resp string) error { + buffer.Write(resp) + return nil + }, + ) admin.AnalysisRequest(model, buffer, err) if err != nil { diff --git a/manager/transhipment.go b/manager/transhipment.go index a0e0ba5..6b93014 100644 --- a/manager/transhipment.go +++ b/manager/transhipment.go @@ -182,7 +182,7 @@ func sendTranshipmentResponse(c *gin.Context, form TranshipmentForm, id string, cache := utils.GetCacheFromContext(c) buffer := utils.NewBuffer(form.Model, form.Messages, channel.ChargeInstance.GetCharge(form.Model)) - err := channel.NewChatRequest(GetProps(form, buffer, plan), func(data string) error { + err := channel.NewChatRequest(auth.GetGroup(db, user), GetProps(form, buffer, plan), func(data string) error { buffer.Write(data) return nil }) @@ -251,7 +251,7 @@ func sendStreamTranshipmentResponse(c *gin.Context, form TranshipmentForm, id st go func() { buffer := utils.NewBuffer(form.Model, form.Messages, channel.ChargeInstance.GetCharge(form.Model)) - err := channel.NewChatRequest(GetProps(form, buffer, plan), func(data string) error { + err := channel.NewChatRequest(auth.GetGroup(db, user), GetProps(form, buffer, plan), func(data string) error { partial <- getStreamTranshipmentForm(id, created, form, buffer.Write(data), buffer, false, nil) return nil })