mirror of
https://github.com/coaidev/coai.git
synced 2025-05-19 21:10:18 +09:00
feat: channel group
This commit is contained in:
parent
b0174a5db3
commit
5affcbb4db
@ -23,6 +23,7 @@ import (
|
||||
type RequestProps struct {
|
||||
MaxRetries *int
|
||||
Current int
|
||||
Group string
|
||||
}
|
||||
|
||||
type ChatProps struct {
|
||||
|
@ -36,7 +36,8 @@ func NewChatRequest(conf globals.ChannelConfig, props *ChatProps, hook globals.H
|
||||
}
|
||||
|
||||
if props.Current < retries {
|
||||
globals.Warn(fmt.Sprintf("retrying chat request for %s (attempt %d/%d, error: %s)", props.Model, props.Current+1, retries, err.Error()))
|
||||
content := strings.Replace(err.Error(), "\n", "", -1)
|
||||
globals.Warn(fmt.Sprintf("retrying chat request for %s (attempt %d/%d, error: %s)", props.Model, props.Current+1, retries, content))
|
||||
return NewChatRequest(conf, props, hook)
|
||||
}
|
||||
}
|
||||
|
@ -72,14 +72,20 @@ func GenerateAPI(c *gin.Context) {
|
||||
}
|
||||
|
||||
var instance *utils.Buffer
|
||||
hash, err := CreateGenerationWithCache(form.Model, form.Prompt, plan, func(buffer *utils.Buffer, data string) {
|
||||
hash, err := CreateGenerationWithCache(
|
||||
auth.GetGroup(db, user),
|
||||
form.Model,
|
||||
form.Prompt,
|
||||
plan,
|
||||
func(buffer *utils.Buffer, data string) {
|
||||
instance = buffer
|
||||
conn.Send(globals.GenerationSegmentResponse{
|
||||
End: false,
|
||||
Message: data,
|
||||
Quota: buffer.GetQuota(),
|
||||
})
|
||||
})
|
||||
},
|
||||
)
|
||||
|
||||
if instance != nil && !plan && instance.GetQuota() > 0 && user != nil {
|
||||
user.UseQuota(db, instance.GetQuota())
|
||||
|
@ -6,10 +6,10 @@ import (
|
||||
"fmt"
|
||||
)
|
||||
|
||||
func CreateGenerationWithCache(model string, prompt string, enableReverse bool, hook func(buffer *utils.Buffer, data string)) (string, error) {
|
||||
func CreateGenerationWithCache(group, model, prompt string, enableReverse bool, hook func(buffer *utils.Buffer, data string)) (string, error) {
|
||||
hash, path := GetFolderByHash(model, prompt)
|
||||
if !utils.Exists(path) {
|
||||
if err := CreateGeneration(model, prompt, path, enableReverse, hook); err != nil {
|
||||
if err := CreateGeneration(group, model, prompt, path, enableReverse, hook); err != nil {
|
||||
globals.Info(fmt.Sprintf("[project] error during generation %s (model %s): %s", prompt, model, err.Error()))
|
||||
return "", fmt.Errorf("error during generate project: %s", err.Error())
|
||||
}
|
||||
|
@ -13,11 +13,11 @@ type ProjectResult struct {
|
||||
Result map[string]interface{} `json:"result"`
|
||||
}
|
||||
|
||||
func CreateGeneration(model string, prompt string, path string, plan bool, hook func(buffer *utils.Buffer, data string)) error {
|
||||
func CreateGeneration(group, model, prompt, path string, plan bool, hook func(buffer *utils.Buffer, data string)) error {
|
||||
message := GenerateMessage(prompt)
|
||||
buffer := utils.NewBuffer(model, message, channel.ChargeInstance.GetCharge(model))
|
||||
|
||||
err := channel.NewChatRequest(&adapter.ChatProps{
|
||||
err := channel.NewChatRequest(group, &adapter.ChatProps{
|
||||
Model: model,
|
||||
Message: message,
|
||||
Plan: plan,
|
||||
|
@ -10,6 +10,7 @@ export type Channel = {
|
||||
endpoint: string;
|
||||
mapper: string;
|
||||
state: boolean;
|
||||
group?: string[];
|
||||
};
|
||||
|
||||
export type ChannelInfo = {
|
||||
@ -164,6 +165,14 @@ export const channelModels: string[] = Object.values(ChannelInfos).flatMap(
|
||||
(info) => info.models,
|
||||
);
|
||||
|
||||
export const channelGroups: string[] = [
|
||||
"anonymous",
|
||||
"normal",
|
||||
"basic",
|
||||
"standard",
|
||||
"pro",
|
||||
];
|
||||
|
||||
export function getChannelInfo(type?: string): ChannelInfo {
|
||||
if (type && type in ChannelInfos) return ChannelInfos[type];
|
||||
return ChannelInfos.openai;
|
||||
|
@ -75,6 +75,7 @@
|
||||
border-radius: var(--radius);
|
||||
transition: .25s;
|
||||
height: max-content;
|
||||
white-space: break-spaces;
|
||||
|
||||
&:hover {
|
||||
border-color: hsl(var(--border-hover));
|
||||
@ -87,6 +88,7 @@
|
||||
margin-left: 0.5rem;
|
||||
color: hsl(var(--text-secondary));
|
||||
transition: .25s;
|
||||
flex-shrink: 0;
|
||||
|
||||
&:hover {
|
||||
color: hsl(var(--text-primary));
|
||||
@ -101,6 +103,12 @@
|
||||
width: 100%;
|
||||
flex-wrap: wrap;
|
||||
gap: 0.5rem;
|
||||
|
||||
@media (max-width: 620px) {
|
||||
& > * {
|
||||
width: 100%;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
.channel-description {
|
||||
|
@ -1,3 +1,7 @@
|
||||
.gold-text {
|
||||
color: hsl(var(--gold)) !important;
|
||||
}
|
||||
|
||||
.select-group {
|
||||
display: flex;
|
||||
flex-direction: row;
|
||||
|
@ -10,6 +10,7 @@ import {
|
||||
} from "@/components/ui/select.tsx";
|
||||
import {
|
||||
Channel,
|
||||
channelGroups,
|
||||
channelModels,
|
||||
ChannelTypes,
|
||||
getChannelInfo,
|
||||
@ -54,6 +55,7 @@ const initialState: Channel = {
|
||||
endpoint: getChannelInfo().endpoint,
|
||||
mapper: "",
|
||||
state: true,
|
||||
group: [],
|
||||
};
|
||||
|
||||
type CustomActionProps = {
|
||||
@ -133,6 +135,18 @@ function reducer(state: Channel, action: any) {
|
||||
return { ...state, retry: action.value };
|
||||
case "clear":
|
||||
return { ...initialState };
|
||||
case "add-group":
|
||||
return {
|
||||
...state,
|
||||
group: state.group ? [...state.group, action.value] : [action.value],
|
||||
};
|
||||
case "remove-group":
|
||||
return {
|
||||
...state,
|
||||
group: state.group
|
||||
? state.group.filter((group) => group !== action.value)
|
||||
: [],
|
||||
};
|
||||
case "set":
|
||||
return { ...state, ...action.value };
|
||||
default:
|
||||
@ -171,6 +185,9 @@ function handler(data: Channel): Channel {
|
||||
);
|
||||
})
|
||||
.join("\n");
|
||||
data.group = data.group
|
||||
? data.group.filter((group) => group.trim() !== "")
|
||||
: [];
|
||||
return data;
|
||||
}
|
||||
|
||||
@ -191,6 +208,13 @@ function ChannelEditor({ display, id, setEnabled }: ChannelEditorProps) {
|
||||
}, [edit.models]);
|
||||
const enabled = useMemo(() => validator(edit), [edit]);
|
||||
|
||||
const unusedGroups = useMemo(() => {
|
||||
if (!edit.group) return channelGroups;
|
||||
return channelGroups.filter(
|
||||
(group) => !edit.group.includes(group) && group !== "",
|
||||
);
|
||||
}, [edit.group]);
|
||||
|
||||
function close(clear?: boolean) {
|
||||
if (clear) dispatch({ type: "clear" });
|
||||
setEnabled(false);
|
||||
@ -409,6 +433,69 @@ function ChannelEditor({ display, id, setEnabled }: ChannelEditorProps) {
|
||||
}
|
||||
/>
|
||||
</div>
|
||||
<div className={`channel-row`}>
|
||||
<div className={`channel-content`}>
|
||||
{t("admin.channels.group")}
|
||||
<Tips content={t("admin.channels.group-tip")} />
|
||||
</div>
|
||||
<div className={`flex flex-row gap-2 items-center`}>
|
||||
<div
|
||||
className={`channel-model-wrapper`}
|
||||
style={{ minHeight: "2.5rem" }}
|
||||
>
|
||||
{(edit.group || []).map((item: string, idx: number) => (
|
||||
<div className={`channel-model-item`} key={idx}>
|
||||
{item}
|
||||
<X
|
||||
className={`remove-action`}
|
||||
onClick={() =>
|
||||
dispatch({ type: "remove-group", value: item })
|
||||
}
|
||||
/>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
<DropdownMenu>
|
||||
<DropdownMenuTrigger asChild>
|
||||
<Button
|
||||
disabled={unusedGroups.length === 0}
|
||||
className={`h-full`}
|
||||
>
|
||||
<Plus className={`w-4 h-4 mr-1`} />
|
||||
{t("add")}
|
||||
</Button>
|
||||
</DropdownMenuTrigger>
|
||||
<DropdownMenuContent align={`end`} asChild>
|
||||
<Command>
|
||||
<CommandList
|
||||
className={
|
||||
unusedGroups.length ? `thin-scrollbar` : `no-scrollbar`
|
||||
}
|
||||
>
|
||||
{unusedGroups.length === 0 ? (
|
||||
<p className={`p-2 text-center`}>
|
||||
{t("conversation.empty")}
|
||||
</p>
|
||||
) : (
|
||||
unusedGroups.map((item, idx) => (
|
||||
<CommandItem
|
||||
key={idx}
|
||||
value={item}
|
||||
onSelect={() =>
|
||||
dispatch({ type: "add-group", value: item })
|
||||
}
|
||||
className={`px-2 ${idx > 1 ? "gold-text" : ""}`}
|
||||
>
|
||||
{item}
|
||||
</CommandItem>
|
||||
))
|
||||
)}
|
||||
</CommandList>
|
||||
</Command>
|
||||
</DropdownMenuContent>
|
||||
</DropdownMenu>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<div className={`mt-4 flex flex-row w-full h-max pr-2 items-center`}>
|
||||
<div className={`object-id`}>
|
||||
|
@ -397,6 +397,9 @@ const resources = {
|
||||
"请输入模型映射,一行一个,格式: model>model\n" +
|
||||
"前者为请求的模型,后者为映射的模型(需要在模型中存在),中间用 > 分隔\n" +
|
||||
"格式前加!表示原模型不包含在此渠道的可用范围内,如: !gpt-4-slow>gpt-4,那么 gpt-4 将不会被涵盖在此渠道的可请求模型中",
|
||||
group: "用户分组",
|
||||
"group-tip":
|
||||
"用户分组,未包含的分组将不包含在此渠道的可用范围内 (分组为空时,所有用户都可以使用此渠道)",
|
||||
state: "状态",
|
||||
action: "操作",
|
||||
edit: "编辑渠道",
|
||||
@ -862,6 +865,9 @@ const resources = {
|
||||
"Please enter the model mapper, one line each, format: model>model\n" +
|
||||
"The former is the requested model, and the latter is the mapped model (which needs to exist in the model), separated by > in the middle\n" +
|
||||
"The format is preceded by! Indicates that the original model is not included in the available range of this channel, such as: !gpt-4-slow>gpt-4, then gpt-4 will not be covered in the available models that can be requested in this channel",
|
||||
group: "User Group",
|
||||
"group-tip":
|
||||
"User group, the group that is not included will not be included in the available range of this channel (when the group is empty, all users can use this channel)",
|
||||
state: "State",
|
||||
action: "Action",
|
||||
edit: "Edit Channel",
|
||||
@ -1331,6 +1337,9 @@ const resources = {
|
||||
"Введите модельный маппер, по одной строке, формат: model>model\n" +
|
||||
"Первая модель - запрошенная модель, вторая модель - отображаемая модель (которая должна существовать в модели), разделенная > посередине\n" +
|
||||
"Формат предшествует! Означает, что исходная модель не включена в доступный диапазон этого канала, например: !gpt-4-slow>gpt-4, тогда gpt-4 не будет охвачен в доступных моделях, которые можно запросить в этом канале",
|
||||
group: "Группа пользователей",
|
||||
"group-tip":
|
||||
"Группа пользователей, группа, которая не включена, не будет включена в доступный диапазон этого канала (когда группа пуста, все пользователи могут использовать этот канал)",
|
||||
state: "Статус",
|
||||
action: "Действие",
|
||||
edit: "Редактировать канал",
|
||||
|
12
app/src/translator/index.ts
Normal file
12
app/src/translator/index.ts
Normal file
@ -0,0 +1,12 @@
|
||||
import { Plugin } from "vite";
|
||||
import path from "path";
|
||||
|
||||
export function createTranslationPlugin(): Plugin {
|
||||
return {
|
||||
name: "translate-plugin",
|
||||
apply: "build",
|
||||
configResolved(config) {
|
||||
const dir = path.resolve(config.root, "src");
|
||||
},
|
||||
};
|
||||
}
|
@ -6,5 +6,5 @@
|
||||
"moduleResolution": "bundler",
|
||||
"allowSyntheticDefaultImports": true
|
||||
},
|
||||
"include": ["vite.config.ts"]
|
||||
"include": ["vite.config.ts", "src/translator"]
|
||||
}
|
||||
|
@ -2,6 +2,7 @@ import { defineConfig } from 'vite'
|
||||
import react from '@vitejs/plugin-react-swc'
|
||||
import path from "path"
|
||||
import { createHtmlPlugin } from 'vite-plugin-html'
|
||||
import { createTranslationPlugin } from "./src/translator";
|
||||
|
||||
// https://vitejs.dev/config/
|
||||
export default defineConfig({
|
||||
@ -10,6 +11,7 @@ export default defineConfig({
|
||||
createHtmlPlugin({
|
||||
minify: true,
|
||||
}),
|
||||
createTranslationPlugin(),
|
||||
],
|
||||
resolve: {
|
||||
alias: {
|
||||
|
@ -1,6 +1,7 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"chat/globals"
|
||||
"database/sql"
|
||||
"time"
|
||||
)
|
||||
@ -62,3 +63,23 @@ func IsUserExist(db *sql.DB, username string) bool {
|
||||
}
|
||||
return count > 0
|
||||
}
|
||||
|
||||
func GetGroup(db *sql.DB, user *User) string {
|
||||
if user == nil {
|
||||
return globals.AnonymousType
|
||||
}
|
||||
|
||||
level := user.GetSubscriptionLevel(db)
|
||||
switch level {
|
||||
case 0:
|
||||
return globals.NormalType
|
||||
case 1:
|
||||
return globals.BasicType
|
||||
case 2:
|
||||
return globals.StandardType
|
||||
case 3:
|
||||
return globals.ProType
|
||||
default:
|
||||
return globals.NormalType
|
||||
}
|
||||
}
|
||||
|
@ -158,6 +158,18 @@ func (c *Channel) GetState() bool {
|
||||
return c.State
|
||||
}
|
||||
|
||||
func (c *Channel) GetGroup() []string {
|
||||
return c.Group
|
||||
}
|
||||
|
||||
func (c *Channel) IsHitGroup(group string) bool {
|
||||
if len(c.GetGroup()) == 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
return utils.Contains(group, c.GetGroup())
|
||||
}
|
||||
|
||||
func (c *Channel) IsHit(model string) bool {
|
||||
return utils.Contains(model, c.GetHitModels())
|
||||
}
|
||||
|
@ -99,10 +99,12 @@ func (m *Manager) HasChannel(model string) bool {
|
||||
return utils.Contains(model, m.Models)
|
||||
}
|
||||
|
||||
func (m *Manager) GetTicker(model string) *Ticker {
|
||||
return &Ticker{
|
||||
Sequence: m.HitSequence(model),
|
||||
func (m *Manager) GetTicker(model, group string) *Ticker {
|
||||
if !m.HasChannel(model) {
|
||||
return nil
|
||||
}
|
||||
|
||||
return NewTicker(m.HitSequence(model), group)
|
||||
}
|
||||
|
||||
func (m *Manager) Len() int {
|
||||
|
@ -2,6 +2,21 @@ package channel
|
||||
|
||||
import "chat/utils"
|
||||
|
||||
func NewTicker(seq Sequence, group string) *Ticker {
|
||||
stack := make(Sequence, 0)
|
||||
for _, channel := range seq {
|
||||
if channel.IsHitGroup(group) {
|
||||
stack = append(stack, channel)
|
||||
}
|
||||
}
|
||||
|
||||
stack.Sort()
|
||||
|
||||
return &Ticker{
|
||||
Sequence: stack,
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Ticker) GetChannelByPriority(priority int) *Channel {
|
||||
var stack Sequence
|
||||
|
||||
@ -72,10 +87,10 @@ func (t *Ticker) SkipPriority(priority int) {
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Ticker) Skip() {
|
||||
t.Cursor++
|
||||
}
|
||||
|
||||
func (t *Ticker) IsDone() bool {
|
||||
return t.Cursor >= len(t.Sequence)
|
||||
}
|
||||
|
||||
func (t *Ticker) IsEmpty() bool {
|
||||
return len(t.Sequence) == 0
|
||||
}
|
||||
|
@ -12,6 +12,7 @@ type Channel struct {
|
||||
Endpoint string `json:"endpoint" mapstructure:"endpoint"`
|
||||
Mapper string `json:"mapper" mapstructure:"mapper"`
|
||||
State bool `json:"state" mapstructure:"state"`
|
||||
Group []string `json:"group" mapstructure:"group"`
|
||||
Reflect *map[string]string `json:"-"`
|
||||
HitModels *[]string `json:"-"`
|
||||
ExcludeModels *[]string `json:"-"`
|
||||
|
@ -5,16 +5,14 @@ import (
|
||||
"chat/globals"
|
||||
"chat/utils"
|
||||
"fmt"
|
||||
"github.com/cloudwego/hertz/cmd/hz/util/logs"
|
||||
)
|
||||
|
||||
func NewChatRequest(props *adapter.ChatProps, hook globals.Hook) error {
|
||||
if !ConduitInstance.HasChannel(props.Model) {
|
||||
func NewChatRequest(group string, props *adapter.ChatProps, hook globals.Hook) error {
|
||||
ticker := ConduitInstance.GetTicker(props.Model, group)
|
||||
if ticker == nil || ticker.IsEmpty() {
|
||||
return fmt.Errorf("cannot find channel for model %s", props.Model)
|
||||
}
|
||||
|
||||
ticker := ConduitInstance.GetTicker(props.Model)
|
||||
|
||||
var err error
|
||||
for !ticker.IsDone() {
|
||||
if channel := ticker.Next(); channel != nil {
|
||||
@ -23,10 +21,10 @@ func NewChatRequest(props *adapter.ChatProps, hook globals.Hook) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
logs.Warn(fmt.Sprintf("[channel] caught error %s for model %s at channel %s", err.Error(), props.Model, channel.GetName()))
|
||||
globals.Warn(fmt.Sprintf("[channel] caught error %s for model %s at channel %s", err.Error(), props.Model, channel.GetName()))
|
||||
}
|
||||
}
|
||||
|
||||
logs.Info(fmt.Sprintf("[channel] channels are exhausted for model %s", props.Model))
|
||||
globals.Info(fmt.Sprintf("[channel] channels are exhausted for model %s", props.Model))
|
||||
return err
|
||||
}
|
||||
|
@ -88,12 +88,15 @@ func ChatHandler(conn *Connection, user *auth.User, instance *conversation.Conve
|
||||
}
|
||||
|
||||
buffer := utils.NewBuffer(model, segment, channel.ChargeInstance.GetCharge(model))
|
||||
err := channel.NewChatRequest(&adapter.ChatProps{
|
||||
err := channel.NewChatRequest(
|
||||
auth.GetGroup(db, user),
|
||||
&adapter.ChatProps{
|
||||
Model: model,
|
||||
Message: segment,
|
||||
Plan: plan,
|
||||
Buffer: *buffer,
|
||||
}, func(data string) error {
|
||||
},
|
||||
func(data string) error {
|
||||
if signal := conn.PeekWithType(StopType); signal != nil {
|
||||
// stop signal from client
|
||||
return fmt.Errorf("signal")
|
||||
@ -104,7 +107,8 @@ func ChatHandler(conn *Connection, user *auth.User, instance *conversation.Conve
|
||||
End: false,
|
||||
Plan: plan,
|
||||
})
|
||||
})
|
||||
},
|
||||
)
|
||||
|
||||
admin.AnalysisRequest(model, buffer, err)
|
||||
if err != nil && err.Error() != "signal" {
|
||||
|
@ -40,15 +40,19 @@ func NativeChatHandler(c *gin.Context, user *auth.User, model string, message []
|
||||
}
|
||||
|
||||
buffer := utils.NewBuffer(model, segment, channel.ChargeInstance.GetCharge(model))
|
||||
err := channel.NewChatRequest(&adapter.ChatProps{
|
||||
err := channel.NewChatRequest(
|
||||
auth.GetGroup(db, user),
|
||||
&adapter.ChatProps{
|
||||
Model: model,
|
||||
Plan: plan,
|
||||
Message: segment,
|
||||
Buffer: *buffer,
|
||||
}, func(resp string) error {
|
||||
},
|
||||
func(resp string) error {
|
||||
buffer.Write(resp)
|
||||
return nil
|
||||
})
|
||||
},
|
||||
)
|
||||
|
||||
admin.AnalysisRequest(model, buffer, err)
|
||||
if err != nil {
|
||||
|
@ -182,7 +182,7 @@ func sendTranshipmentResponse(c *gin.Context, form TranshipmentForm, id string,
|
||||
cache := utils.GetCacheFromContext(c)
|
||||
|
||||
buffer := utils.NewBuffer(form.Model, form.Messages, channel.ChargeInstance.GetCharge(form.Model))
|
||||
err := channel.NewChatRequest(GetProps(form, buffer, plan), func(data string) error {
|
||||
err := channel.NewChatRequest(auth.GetGroup(db, user), GetProps(form, buffer, plan), func(data string) error {
|
||||
buffer.Write(data)
|
||||
return nil
|
||||
})
|
||||
@ -251,7 +251,7 @@ func sendStreamTranshipmentResponse(c *gin.Context, form TranshipmentForm, id st
|
||||
|
||||
go func() {
|
||||
buffer := utils.NewBuffer(form.Model, form.Messages, channel.ChargeInstance.GetCharge(form.Model))
|
||||
err := channel.NewChatRequest(GetProps(form, buffer, plan), func(data string) error {
|
||||
err := channel.NewChatRequest(auth.GetGroup(db, user), GetProps(form, buffer, plan), func(data string) error {
|
||||
partial <- getStreamTranshipmentForm(id, created, form, buffer.Write(data), buffer, false, nil)
|
||||
return nil
|
||||
})
|
||||
|
Loading…
Reference in New Issue
Block a user