feat: channel group

This commit is contained in:
Zhang Minghan 2023-12-20 19:33:01 +08:00
parent b0174a5db3
commit 5affcbb4db
22 changed files with 251 additions and 55 deletions

View File

@ -23,6 +23,7 @@ import (
type RequestProps struct { type RequestProps struct {
MaxRetries *int MaxRetries *int
Current int Current int
Group string
} }
type ChatProps struct { type ChatProps struct {

View File

@ -36,7 +36,8 @@ func NewChatRequest(conf globals.ChannelConfig, props *ChatProps, hook globals.H
} }
if props.Current < retries { if props.Current < retries {
globals.Warn(fmt.Sprintf("retrying chat request for %s (attempt %d/%d, error: %s)", props.Model, props.Current+1, retries, err.Error())) content := strings.Replace(err.Error(), "\n", "", -1)
globals.Warn(fmt.Sprintf("retrying chat request for %s (attempt %d/%d, error: %s)", props.Model, props.Current+1, retries, content))
return NewChatRequest(conf, props, hook) return NewChatRequest(conf, props, hook)
} }
} }

View File

@ -72,14 +72,20 @@ func GenerateAPI(c *gin.Context) {
} }
var instance *utils.Buffer var instance *utils.Buffer
hash, err := CreateGenerationWithCache(form.Model, form.Prompt, plan, func(buffer *utils.Buffer, data string) { hash, err := CreateGenerationWithCache(
auth.GetGroup(db, user),
form.Model,
form.Prompt,
plan,
func(buffer *utils.Buffer, data string) {
instance = buffer instance = buffer
conn.Send(globals.GenerationSegmentResponse{ conn.Send(globals.GenerationSegmentResponse{
End: false, End: false,
Message: data, Message: data,
Quota: buffer.GetQuota(), Quota: buffer.GetQuota(),
}) })
}) },
)
if instance != nil && !plan && instance.GetQuota() > 0 && user != nil { if instance != nil && !plan && instance.GetQuota() > 0 && user != nil {
user.UseQuota(db, instance.GetQuota()) user.UseQuota(db, instance.GetQuota())

View File

@ -6,10 +6,10 @@ import (
"fmt" "fmt"
) )
func CreateGenerationWithCache(model string, prompt string, enableReverse bool, hook func(buffer *utils.Buffer, data string)) (string, error) { func CreateGenerationWithCache(group, model, prompt string, enableReverse bool, hook func(buffer *utils.Buffer, data string)) (string, error) {
hash, path := GetFolderByHash(model, prompt) hash, path := GetFolderByHash(model, prompt)
if !utils.Exists(path) { if !utils.Exists(path) {
if err := CreateGeneration(model, prompt, path, enableReverse, hook); err != nil { if err := CreateGeneration(group, model, prompt, path, enableReverse, hook); err != nil {
globals.Info(fmt.Sprintf("[project] error during generation %s (model %s): %s", prompt, model, err.Error())) globals.Info(fmt.Sprintf("[project] error during generation %s (model %s): %s", prompt, model, err.Error()))
return "", fmt.Errorf("error during generate project: %s", err.Error()) return "", fmt.Errorf("error during generate project: %s", err.Error())
} }

View File

@ -13,11 +13,11 @@ type ProjectResult struct {
Result map[string]interface{} `json:"result"` Result map[string]interface{} `json:"result"`
} }
func CreateGeneration(model string, prompt string, path string, plan bool, hook func(buffer *utils.Buffer, data string)) error { func CreateGeneration(group, model, prompt, path string, plan bool, hook func(buffer *utils.Buffer, data string)) error {
message := GenerateMessage(prompt) message := GenerateMessage(prompt)
buffer := utils.NewBuffer(model, message, channel.ChargeInstance.GetCharge(model)) buffer := utils.NewBuffer(model, message, channel.ChargeInstance.GetCharge(model))
err := channel.NewChatRequest(&adapter.ChatProps{ err := channel.NewChatRequest(group, &adapter.ChatProps{
Model: model, Model: model,
Message: message, Message: message,
Plan: plan, Plan: plan,

View File

@ -10,6 +10,7 @@ export type Channel = {
endpoint: string; endpoint: string;
mapper: string; mapper: string;
state: boolean; state: boolean;
group?: string[];
}; };
export type ChannelInfo = { export type ChannelInfo = {
@ -164,6 +165,14 @@ export const channelModels: string[] = Object.values(ChannelInfos).flatMap(
(info) => info.models, (info) => info.models,
); );
export const channelGroups: string[] = [
"anonymous",
"normal",
"basic",
"standard",
"pro",
];
export function getChannelInfo(type?: string): ChannelInfo { export function getChannelInfo(type?: string): ChannelInfo {
if (type && type in ChannelInfos) return ChannelInfos[type]; if (type && type in ChannelInfos) return ChannelInfos[type];
return ChannelInfos.openai; return ChannelInfos.openai;

View File

@ -75,6 +75,7 @@
border-radius: var(--radius); border-radius: var(--radius);
transition: .25s; transition: .25s;
height: max-content; height: max-content;
white-space: break-spaces;
&:hover { &:hover {
border-color: hsl(var(--border-hover)); border-color: hsl(var(--border-hover));
@ -87,6 +88,7 @@
margin-left: 0.5rem; margin-left: 0.5rem;
color: hsl(var(--text-secondary)); color: hsl(var(--text-secondary));
transition: .25s; transition: .25s;
flex-shrink: 0;
&:hover { &:hover {
color: hsl(var(--text-primary)); color: hsl(var(--text-primary));
@ -101,6 +103,12 @@
width: 100%; width: 100%;
flex-wrap: wrap; flex-wrap: wrap;
gap: 0.5rem; gap: 0.5rem;
@media (max-width: 620px) {
& > * {
width: 100%;
}
}
} }
.channel-description { .channel-description {

View File

@ -1,3 +1,7 @@
.gold-text {
color: hsl(var(--gold)) !important;
}
.select-group { .select-group {
display: flex; display: flex;
flex-direction: row; flex-direction: row;

View File

@ -10,6 +10,7 @@ import {
} from "@/components/ui/select.tsx"; } from "@/components/ui/select.tsx";
import { import {
Channel, Channel,
channelGroups,
channelModels, channelModels,
ChannelTypes, ChannelTypes,
getChannelInfo, getChannelInfo,
@ -54,6 +55,7 @@ const initialState: Channel = {
endpoint: getChannelInfo().endpoint, endpoint: getChannelInfo().endpoint,
mapper: "", mapper: "",
state: true, state: true,
group: [],
}; };
type CustomActionProps = { type CustomActionProps = {
@ -133,6 +135,18 @@ function reducer(state: Channel, action: any) {
return { ...state, retry: action.value }; return { ...state, retry: action.value };
case "clear": case "clear":
return { ...initialState }; return { ...initialState };
case "add-group":
return {
...state,
group: state.group ? [...state.group, action.value] : [action.value],
};
case "remove-group":
return {
...state,
group: state.group
? state.group.filter((group) => group !== action.value)
: [],
};
case "set": case "set":
return { ...state, ...action.value }; return { ...state, ...action.value };
default: default:
@ -171,6 +185,9 @@ function handler(data: Channel): Channel {
); );
}) })
.join("\n"); .join("\n");
data.group = data.group
? data.group.filter((group) => group.trim() !== "")
: [];
return data; return data;
} }
@ -191,6 +208,13 @@ function ChannelEditor({ display, id, setEnabled }: ChannelEditorProps) {
}, [edit.models]); }, [edit.models]);
const enabled = useMemo(() => validator(edit), [edit]); const enabled = useMemo(() => validator(edit), [edit]);
const unusedGroups = useMemo(() => {
if (!edit.group) return channelGroups;
return channelGroups.filter(
(group) => !edit.group.includes(group) && group !== "",
);
}, [edit.group]);
function close(clear?: boolean) { function close(clear?: boolean) {
if (clear) dispatch({ type: "clear" }); if (clear) dispatch({ type: "clear" });
setEnabled(false); setEnabled(false);
@ -409,6 +433,69 @@ function ChannelEditor({ display, id, setEnabled }: ChannelEditorProps) {
} }
/> />
</div> </div>
<div className={`channel-row`}>
<div className={`channel-content`}>
{t("admin.channels.group")}
<Tips content={t("admin.channels.group-tip")} />
</div>
<div className={`flex flex-row gap-2 items-center`}>
<div
className={`channel-model-wrapper`}
style={{ minHeight: "2.5rem" }}
>
{(edit.group || []).map((item: string, idx: number) => (
<div className={`channel-model-item`} key={idx}>
{item}
<X
className={`remove-action`}
onClick={() =>
dispatch({ type: "remove-group", value: item })
}
/>
</div>
))}
</div>
<DropdownMenu>
<DropdownMenuTrigger asChild>
<Button
disabled={unusedGroups.length === 0}
className={`h-full`}
>
<Plus className={`w-4 h-4 mr-1`} />
{t("add")}
</Button>
</DropdownMenuTrigger>
<DropdownMenuContent align={`end`} asChild>
<Command>
<CommandList
className={
unusedGroups.length ? `thin-scrollbar` : `no-scrollbar`
}
>
{unusedGroups.length === 0 ? (
<p className={`p-2 text-center`}>
{t("conversation.empty")}
</p>
) : (
unusedGroups.map((item, idx) => (
<CommandItem
key={idx}
value={item}
onSelect={() =>
dispatch({ type: "add-group", value: item })
}
className={`px-2 ${idx > 1 ? "gold-text" : ""}`}
>
{item}
</CommandItem>
))
)}
</CommandList>
</Command>
</DropdownMenuContent>
</DropdownMenu>
</div>
</div>
</div> </div>
<div className={`mt-4 flex flex-row w-full h-max pr-2 items-center`}> <div className={`mt-4 flex flex-row w-full h-max pr-2 items-center`}>
<div className={`object-id`}> <div className={`object-id`}>

View File

@ -397,6 +397,9 @@ const resources = {
"请输入模型映射,一行一个,格式: model>model\n" + "请输入模型映射,一行一个,格式: model>model\n" +
"前者为请求的模型,后者为映射的模型(需要在模型中存在),中间用 > 分隔\n" + "前者为请求的模型,后者为映射的模型(需要在模型中存在),中间用 > 分隔\n" +
"格式前加!表示原模型不包含在此渠道的可用范围内,如: !gpt-4-slow>gpt-4那么 gpt-4 将不会被涵盖在此渠道的可请求模型中", "格式前加!表示原模型不包含在此渠道的可用范围内,如: !gpt-4-slow>gpt-4那么 gpt-4 将不会被涵盖在此渠道的可请求模型中",
group: "用户分组",
"group-tip":
"用户分组,未包含的分组将不包含在此渠道的可用范围内 (分组为空时,所有用户都可以使用此渠道)",
state: "状态", state: "状态",
action: "操作", action: "操作",
edit: "编辑渠道", edit: "编辑渠道",
@ -862,6 +865,9 @@ const resources = {
"Please enter the model mapper, one line each, format: model>model\n" + "Please enter the model mapper, one line each, format: model>model\n" +
"The former is the requested model, and the latter is the mapped model (which needs to exist in the model), separated by > in the middle\n" + "The former is the requested model, and the latter is the mapped model (which needs to exist in the model), separated by > in the middle\n" +
"The format is preceded by! Indicates that the original model is not included in the available range of this channel, such as: !gpt-4-slow>gpt-4, then gpt-4 will not be covered in the available models that can be requested in this channel", "The format is preceded by! Indicates that the original model is not included in the available range of this channel, such as: !gpt-4-slow>gpt-4, then gpt-4 will not be covered in the available models that can be requested in this channel",
group: "User Group",
"group-tip":
"User group, the group that is not included will not be included in the available range of this channel (when the group is empty, all users can use this channel)",
state: "State", state: "State",
action: "Action", action: "Action",
edit: "Edit Channel", edit: "Edit Channel",
@ -1331,6 +1337,9 @@ const resources = {
"Введите модельный маппер, по одной строке, формат: model>model\n" + "Введите модельный маппер, по одной строке, формат: model>model\n" +
"Первая модель - запрошенная модель, вторая модель - отображаемая модель (которая должна существовать в модели), разделенная > посередине\n" + "Первая модель - запрошенная модель, вторая модель - отображаемая модель (которая должна существовать в модели), разделенная > посередине\n" +
"Формат предшествует! Означает, что исходная модель не включена в доступный диапазон этого канала, например: !gpt-4-slow>gpt-4, тогда gpt-4 не будет охвачен в доступных моделях, которые можно запросить в этом канале", "Формат предшествует! Означает, что исходная модель не включена в доступный диапазон этого канала, например: !gpt-4-slow>gpt-4, тогда gpt-4 не будет охвачен в доступных моделях, которые можно запросить в этом канале",
group: "Группа пользователей",
"group-tip":
"Группа пользователей, группа, которая не включена, не будет включена в доступный диапазон этого канала (когда группа пуста, все пользователи могут использовать этот канал)",
state: "Статус", state: "Статус",
action: "Действие", action: "Действие",
edit: "Редактировать канал", edit: "Редактировать канал",

View File

@ -0,0 +1,12 @@
import { Plugin } from "vite";
import path from "path";
export function createTranslationPlugin(): Plugin {
return {
name: "translate-plugin",
apply: "build",
configResolved(config) {
const dir = path.resolve(config.root, "src");
},
};
}

View File

@ -6,5 +6,5 @@
"moduleResolution": "bundler", "moduleResolution": "bundler",
"allowSyntheticDefaultImports": true "allowSyntheticDefaultImports": true
}, },
"include": ["vite.config.ts"] "include": ["vite.config.ts", "src/translator"]
} }

View File

@ -2,6 +2,7 @@ import { defineConfig } from 'vite'
import react from '@vitejs/plugin-react-swc' import react from '@vitejs/plugin-react-swc'
import path from "path" import path from "path"
import { createHtmlPlugin } from 'vite-plugin-html' import { createHtmlPlugin } from 'vite-plugin-html'
import { createTranslationPlugin } from "./src/translator";
// https://vitejs.dev/config/ // https://vitejs.dev/config/
export default defineConfig({ export default defineConfig({
@ -10,6 +11,7 @@ export default defineConfig({
createHtmlPlugin({ createHtmlPlugin({
minify: true, minify: true,
}), }),
createTranslationPlugin(),
], ],
resolve: { resolve: {
alias: { alias: {

View File

@ -1,6 +1,7 @@
package auth package auth
import ( import (
"chat/globals"
"database/sql" "database/sql"
"time" "time"
) )
@ -62,3 +63,23 @@ func IsUserExist(db *sql.DB, username string) bool {
} }
return count > 0 return count > 0
} }
func GetGroup(db *sql.DB, user *User) string {
if user == nil {
return globals.AnonymousType
}
level := user.GetSubscriptionLevel(db)
switch level {
case 0:
return globals.NormalType
case 1:
return globals.BasicType
case 2:
return globals.StandardType
case 3:
return globals.ProType
default:
return globals.NormalType
}
}

View File

@ -158,6 +158,18 @@ func (c *Channel) GetState() bool {
return c.State return c.State
} }
func (c *Channel) GetGroup() []string {
return c.Group
}
func (c *Channel) IsHitGroup(group string) bool {
if len(c.GetGroup()) == 0 {
return true
}
return utils.Contains(group, c.GetGroup())
}
func (c *Channel) IsHit(model string) bool { func (c *Channel) IsHit(model string) bool {
return utils.Contains(model, c.GetHitModels()) return utils.Contains(model, c.GetHitModels())
} }

View File

@ -99,10 +99,12 @@ func (m *Manager) HasChannel(model string) bool {
return utils.Contains(model, m.Models) return utils.Contains(model, m.Models)
} }
func (m *Manager) GetTicker(model string) *Ticker { func (m *Manager) GetTicker(model, group string) *Ticker {
return &Ticker{ if !m.HasChannel(model) {
Sequence: m.HitSequence(model), return nil
} }
return NewTicker(m.HitSequence(model), group)
} }
func (m *Manager) Len() int { func (m *Manager) Len() int {

View File

@ -2,6 +2,21 @@ package channel
import "chat/utils" import "chat/utils"
func NewTicker(seq Sequence, group string) *Ticker {
stack := make(Sequence, 0)
for _, channel := range seq {
if channel.IsHitGroup(group) {
stack = append(stack, channel)
}
}
stack.Sort()
return &Ticker{
Sequence: stack,
}
}
func (t *Ticker) GetChannelByPriority(priority int) *Channel { func (t *Ticker) GetChannelByPriority(priority int) *Channel {
var stack Sequence var stack Sequence
@ -72,10 +87,10 @@ func (t *Ticker) SkipPriority(priority int) {
} }
} }
func (t *Ticker) Skip() {
t.Cursor++
}
func (t *Ticker) IsDone() bool { func (t *Ticker) IsDone() bool {
return t.Cursor >= len(t.Sequence) return t.Cursor >= len(t.Sequence)
} }
func (t *Ticker) IsEmpty() bool {
return len(t.Sequence) == 0
}

View File

@ -12,6 +12,7 @@ type Channel struct {
Endpoint string `json:"endpoint" mapstructure:"endpoint"` Endpoint string `json:"endpoint" mapstructure:"endpoint"`
Mapper string `json:"mapper" mapstructure:"mapper"` Mapper string `json:"mapper" mapstructure:"mapper"`
State bool `json:"state" mapstructure:"state"` State bool `json:"state" mapstructure:"state"`
Group []string `json:"group" mapstructure:"group"`
Reflect *map[string]string `json:"-"` Reflect *map[string]string `json:"-"`
HitModels *[]string `json:"-"` HitModels *[]string `json:"-"`
ExcludeModels *[]string `json:"-"` ExcludeModels *[]string `json:"-"`

View File

@ -5,16 +5,14 @@ import (
"chat/globals" "chat/globals"
"chat/utils" "chat/utils"
"fmt" "fmt"
"github.com/cloudwego/hertz/cmd/hz/util/logs"
) )
func NewChatRequest(props *adapter.ChatProps, hook globals.Hook) error { func NewChatRequest(group string, props *adapter.ChatProps, hook globals.Hook) error {
if !ConduitInstance.HasChannel(props.Model) { ticker := ConduitInstance.GetTicker(props.Model, group)
if ticker == nil || ticker.IsEmpty() {
return fmt.Errorf("cannot find channel for model %s", props.Model) return fmt.Errorf("cannot find channel for model %s", props.Model)
} }
ticker := ConduitInstance.GetTicker(props.Model)
var err error var err error
for !ticker.IsDone() { for !ticker.IsDone() {
if channel := ticker.Next(); channel != nil { if channel := ticker.Next(); channel != nil {
@ -23,10 +21,10 @@ func NewChatRequest(props *adapter.ChatProps, hook globals.Hook) error {
return nil return nil
} }
logs.Warn(fmt.Sprintf("[channel] caught error %s for model %s at channel %s", err.Error(), props.Model, channel.GetName())) globals.Warn(fmt.Sprintf("[channel] caught error %s for model %s at channel %s", err.Error(), props.Model, channel.GetName()))
} }
} }
logs.Info(fmt.Sprintf("[channel] channels are exhausted for model %s", props.Model)) globals.Info(fmt.Sprintf("[channel] channels are exhausted for model %s", props.Model))
return err return err
} }

View File

@ -88,12 +88,15 @@ func ChatHandler(conn *Connection, user *auth.User, instance *conversation.Conve
} }
buffer := utils.NewBuffer(model, segment, channel.ChargeInstance.GetCharge(model)) buffer := utils.NewBuffer(model, segment, channel.ChargeInstance.GetCharge(model))
err := channel.NewChatRequest(&adapter.ChatProps{ err := channel.NewChatRequest(
auth.GetGroup(db, user),
&adapter.ChatProps{
Model: model, Model: model,
Message: segment, Message: segment,
Plan: plan, Plan: plan,
Buffer: *buffer, Buffer: *buffer,
}, func(data string) error { },
func(data string) error {
if signal := conn.PeekWithType(StopType); signal != nil { if signal := conn.PeekWithType(StopType); signal != nil {
// stop signal from client // stop signal from client
return fmt.Errorf("signal") return fmt.Errorf("signal")
@ -104,7 +107,8 @@ func ChatHandler(conn *Connection, user *auth.User, instance *conversation.Conve
End: false, End: false,
Plan: plan, Plan: plan,
}) })
}) },
)
admin.AnalysisRequest(model, buffer, err) admin.AnalysisRequest(model, buffer, err)
if err != nil && err.Error() != "signal" { if err != nil && err.Error() != "signal" {

View File

@ -40,15 +40,19 @@ func NativeChatHandler(c *gin.Context, user *auth.User, model string, message []
} }
buffer := utils.NewBuffer(model, segment, channel.ChargeInstance.GetCharge(model)) buffer := utils.NewBuffer(model, segment, channel.ChargeInstance.GetCharge(model))
err := channel.NewChatRequest(&adapter.ChatProps{ err := channel.NewChatRequest(
auth.GetGroup(db, user),
&adapter.ChatProps{
Model: model, Model: model,
Plan: plan, Plan: plan,
Message: segment, Message: segment,
Buffer: *buffer, Buffer: *buffer,
}, func(resp string) error { },
func(resp string) error {
buffer.Write(resp) buffer.Write(resp)
return nil return nil
}) },
)
admin.AnalysisRequest(model, buffer, err) admin.AnalysisRequest(model, buffer, err)
if err != nil { if err != nil {

View File

@ -182,7 +182,7 @@ func sendTranshipmentResponse(c *gin.Context, form TranshipmentForm, id string,
cache := utils.GetCacheFromContext(c) cache := utils.GetCacheFromContext(c)
buffer := utils.NewBuffer(form.Model, form.Messages, channel.ChargeInstance.GetCharge(form.Model)) buffer := utils.NewBuffer(form.Model, form.Messages, channel.ChargeInstance.GetCharge(form.Model))
err := channel.NewChatRequest(GetProps(form, buffer, plan), func(data string) error { err := channel.NewChatRequest(auth.GetGroup(db, user), GetProps(form, buffer, plan), func(data string) error {
buffer.Write(data) buffer.Write(data)
return nil return nil
}) })
@ -251,7 +251,7 @@ func sendStreamTranshipmentResponse(c *gin.Context, form TranshipmentForm, id st
go func() { go func() {
buffer := utils.NewBuffer(form.Model, form.Messages, channel.ChargeInstance.GetCharge(form.Model)) buffer := utils.NewBuffer(form.Model, form.Messages, channel.ChargeInstance.GetCharge(form.Model))
err := channel.NewChatRequest(GetProps(form, buffer, plan), func(data string) error { err := channel.NewChatRequest(auth.GetGroup(db, user), GetProps(form, buffer, plan), func(data string) error {
partial <- getStreamTranshipmentForm(id, created, form, buffer.Write(data), buffer, false, nil) partial <- getStreamTranshipmentForm(id, created, form, buffer.Write(data), buffer, false, nil)
return nil return nil
}) })