release: channel feature

This commit is contained in:
Zhang Minghan 2023-12-03 11:13:58 +08:00
parent e6ca4c7015
commit b2c2c2d92d
11 changed files with 84 additions and 42 deletions

View File

@ -31,7 +31,7 @@ export class Conversation {
if (id === -1 && this.idx === -1) { if (id === -1 && this.idx === -1) {
sharingEvent.bind(({ refer, data }) => { sharingEvent.bind(({ refer, data }) => {
console.log( console.debug(
`[conversation] load from sharing event (ref: ${refer}, length: ${data.length})`, `[conversation] load from sharing event (ref: ${refer}, length: ${data.length})`,
); );

View File

@ -192,17 +192,21 @@ function ChannelEditor({ display, id, setEnabled }: ChannelEditorProps) {
}, [edit.models]); }, [edit.models]);
const enabled = useMemo(() => validator(edit), [edit]); const enabled = useMemo(() => validator(edit), [edit]);
function close(clear?: boolean) {
if (clear) dispatch({ type: "clear" });
setEnabled(false);
}
async function post() { async function post() {
const data = handler(edit); const data = handler(edit);
console.debug(`[channel] preflight channel data`, data); console.debug(`[channel] preflight channel data`, data);
const resp = const resp =
id === -1 ? await createChannel(data) : await updateChannel(id, data); id === -1 ? await createChannel(data) : await updateChannel(id, data);
toastState(toast, t, resp as ChannelCommonResponse); toastState(toast, t, resp as ChannelCommonResponse, true);
if (resp.status) { if (resp.status) {
dispatch({ type: "clear" }); close(true);
setEnabled(false);
} }
} }
@ -211,7 +215,6 @@ function ChannelEditor({ display, id, setEnabled }: ChannelEditorProps) {
else { else {
const resp = await getChannel(id); const resp = await getChannel(id);
toastState(toast, t, resp as ChannelCommonResponse); toastState(toast, t, resp as ChannelCommonResponse);
console.log(resp);
if (resp.data) dispatch({ type: "set", value: resp.data }); if (resp.data) dispatch({ type: "set", value: resp.data });
} }
}, [id]); }, [id]);
@ -309,7 +312,8 @@ function ChannelEditor({ display, id, setEnabled }: ChannelEditorProps) {
</DropdownMenu> </DropdownMenu>
<CustomAction <CustomAction
onPost={(model) => { onPost={(model) => {
dispatch({ type: "add-model", value: model }); const models = model.split(" ");
dispatch({ type: "add-models", value: models });
}} }}
/> />
<Button <Button
@ -406,7 +410,7 @@ function ChannelEditor({ display, id, setEnabled }: ChannelEditorProps) {
</div> </div>
<div className={`mt-4 flex flex-row w-full h-max pr-2`}> <div className={`mt-4 flex flex-row w-full h-max pr-2`}>
<div className={`grow`} /> <div className={`grow`} />
<Button variant={`outline`} onClick={() => setEnabled(false)}> <Button variant={`outline`} onClick={() => close()}>
{t("cancel")} {t("cancel")}
</Button> </Button>
<Button className={`ml-2`} onClick={post} disabled={!enabled}> <Button className={`ml-2`} onClick={post} disabled={!enabled}>

View File

@ -9,7 +9,7 @@ import { Badge } from "@/components/ui/badge.tsx";
import { Check, Plus, RotateCw, Settings2, Trash, X } from "lucide-react"; import { Check, Plus, RotateCw, Settings2, Trash, X } from "lucide-react";
import { Button } from "@/components/ui/button.tsx"; import { Button } from "@/components/ui/button.tsx";
import OperationAction from "@/components/OperationAction.tsx"; import OperationAction from "@/components/OperationAction.tsx";
import { useMemo, useState } from "react"; import { useEffect, useMemo, useState } from "react";
import { Channel, getChannelType, toastState } from "@/admin/channel.ts"; import { Channel, getChannelType, toastState } from "@/admin/channel.ts";
import { useTranslation } from "react-i18next"; import { useTranslation } from "react-i18next";
import { useEffectAsync } from "@/utils/hook.ts"; import { useEffectAsync } from "@/utils/hook.ts";
@ -54,6 +54,10 @@ function ChannelTable({ display, setId, setEnabled }: ChannelTableProps) {
useEffectAsync(refresh, []); useEffectAsync(refresh, []);
useEffectAsync(refresh, [display]); useEffectAsync(refresh, [display]);
useEffect(() => {
if (display) setId(-1);
}, [display]);
return ( return (
display && ( display && (
<div className={`channel-table`}> <div className={`channel-table`}>

View File

@ -8,7 +8,7 @@ import {
} from "@/utils/env.ts"; } from "@/utils/env.ts";
import { getMemory } from "@/utils/memory.ts"; import { getMemory } from "@/utils/memory.ts";
export const version = "3.7.0"; export const version = "3.7.1";
export const dev: boolean = getDev(); export const dev: boolean = getDev();
export const deploy: boolean = true; export const deploy: boolean = true;
export let rest_api: string = getRestApi(deploy); export let rest_api: string = getRestApi(deploy);

View File

@ -398,7 +398,7 @@ const resources = {
create: "创建渠道", create: "创建渠道",
"search-model": "搜索模型", "search-model": "搜索模型",
"fill-template-models": "填入模板模型 ({{number}} 个)", "fill-template-models": "填入模板模型 ({{number}} 个)",
"add-custom-model": "添加自定义模型", "add-custom-model": "添加自定义模型 (多个模型用空格分隔)",
"add-model": "添加模型", "add-model": "添加模型",
"clear-models": "清空全部模型", "clear-models": "清空全部模型",
}, },
@ -818,7 +818,8 @@ const resources = {
create: "Create Channel", create: "Create Channel",
"search-model": "Search Model", "search-model": "Search Model",
"fill-template-models": "Fill Template Models ({{number}})", "fill-template-models": "Fill Template Models ({{number}})",
"add-custom-model": "Add Custom Model", "add-custom-model":
"Add Custom Model (Multiple models are separated by spaces)",
"add-model": "Add Model", "add-model": "Add Model",
"clear-models": "Clear All Models", "clear-models": "Clear All Models",
}, },
@ -1241,7 +1242,8 @@ const resources = {
create: "Создать канал", create: "Создать канал",
"search-model": "Поиск по имени модели", "search-model": "Поиск по имени модели",
"fill-template-models": "Заполнить шаблонные модели ({{number}})", "fill-template-models": "Заполнить шаблонные модели ({{number}})",
"add-custom-model": "Добавить пользовательскую модель", "add-custom-model":
"Добавить пользовательскую модель (несколько моделей разделяются пробелами)",
"add-model": "Добавить модель", "add-model": "Добавить модель",
"clear-models": "Очистить все модели", "clear-models": "Очистить все модели",
}, },

View File

@ -80,7 +80,7 @@ func (c *Channel) GetMapper() string {
func (c *Channel) GetReflect() map[string]string { func (c *Channel) GetReflect() map[string]string {
if c.Reflect == nil { if c.Reflect == nil {
var reflect map[string]string reflect := make(map[string]string)
arr := strings.Split(c.GetMapper(), "\n") arr := strings.Split(c.GetMapper(), "\n")
for _, item := range arr { for _, item := range arr {
pair := strings.Split(item, ">") pair := strings.Split(item, ">")
@ -126,7 +126,6 @@ func (c *Channel) GetHitModels() []string {
c.HitModels = &res c.HitModels = &res
} }
return *c.HitModels return *c.HitModels
} }

View File

@ -34,7 +34,7 @@ func (m *Manager) Load() {
// init support models // init support models
m.Models = []string{} m.Models = []string{}
for _, channel := range m.GetActiveSequence() { for _, channel := range m.GetActiveSequence() {
for _, model := range channel.GetModels() { for _, model := range channel.GetHitModels() {
if !utils.Contains(model, m.Models) { if !utils.Contains(model, m.Models) {
m.Models = append(m.Models, model) m.Models = append(m.Models, model)
} }
@ -46,7 +46,7 @@ func (m *Manager) Load() {
for _, model := range m.Models { for _, model := range m.Models {
var seq Sequence var seq Sequence
for _, channel := range m.GetActiveSequence() { for _, channel := range m.GetActiveSequence() {
if utils.Contains(model, channel.GetModels()) { if channel.IsHit(model) {
seq = append(seq, channel) seq = append(seq, channel)
} }
} }

View File

@ -7,7 +7,7 @@ func (s *Sequence) Len() int {
} }
func (s *Sequence) Less(i, j int) bool { func (s *Sequence) Less(i, j int) bool {
return (*s)[i].GetPriority() < (*s)[j].GetPriority() return (*s)[i].GetPriority() > (*s)[j].GetPriority()
} }
func (s *Sequence) Swap(i, j int) { func (s *Sequence) Swap(i, j int) {

View File

@ -12,8 +12,8 @@ type Channel struct {
Endpoint string `json:"endpoint" mapstructure:"endpoint"` Endpoint string `json:"endpoint" mapstructure:"endpoint"`
Mapper string `json:"mapper" mapstructure:"mapper"` Mapper string `json:"mapper" mapstructure:"mapper"`
State bool `json:"state" mapstructure:"state"` State bool `json:"state" mapstructure:"state"`
Reflect *map[string]string `json:"reflect" mapstructure:"reflect"` Reflect *map[string]string `json:"-"`
HitModels *[]string `json:"hit_models" mapstructure:"hit_models"` HitModels *[]string `json:"-"`
} }
type Sequence []*Channel type Sequence []*Channel

View File

@ -68,37 +68,58 @@ type TranshipmentStreamResponse struct {
Choices []ChoiceDelta `json:"choices"` Choices []ChoiceDelta `json:"choices"`
Usage Usage `json:"usage"` Usage Usage `json:"usage"`
Quota *float32 `json:"quota,omitempty"` Quota *float32 `json:"quota,omitempty"`
Error error `json:"error,omitempty"`
}
type TranshipmentErrorResponse struct {
Error TranshipmentError `json:"error"`
}
type TranshipmentError struct {
Message string `json:"message"`
Type string `json:"type"`
} }
func ModelAPI(c *gin.Context) { func ModelAPI(c *gin.Context) {
c.JSON(http.StatusOK, channel.ManagerInstance.GetModels()) c.JSON(http.StatusOK, channel.ManagerInstance.GetModels())
} }
func sendErrorResponse(c *gin.Context, err error, types ...string) {
var errType string
if len(types) > 0 {
errType = types[0]
} else {
errType = "chatnio_api_error"
}
c.JSON(http.StatusServiceUnavailable, TranshipmentErrorResponse{
Error: TranshipmentError{
Message: err.Error(),
Type: errType,
},
})
}
func abortWithErrorResponse(c *gin.Context, err error, types ...string) {
sendErrorResponse(c, err, types...)
c.Abort()
}
func TranshipmentAPI(c *gin.Context) { func TranshipmentAPI(c *gin.Context) {
username := utils.GetUserFromContext(c) username := utils.GetUserFromContext(c)
if username == "" { if username == "" {
c.AbortWithStatusJSON(403, gin.H{ abortWithErrorResponse(c, fmt.Errorf("access denied for invalid api key"), "authentication_error")
"code": 403,
"message": "Access denied. Please provide correct api key.",
})
return return
} }
if utils.GetAgentFromContext(c) != "api" { if utils.GetAgentFromContext(c) != "api" {
c.AbortWithStatusJSON(403, gin.H{ abortWithErrorResponse(c, fmt.Errorf("access denied for invalid agent"), "authentication_error")
"code": 403,
"message": "Access denied. Please provide correct api key.",
})
return return
} }
var form TranshipmentForm var form TranshipmentForm
if err := c.ShouldBindJSON(&form); err != nil { if err := c.ShouldBindJSON(&form); err != nil {
c.JSON(400, gin.H{ abortWithErrorResponse(c, fmt.Errorf("invalid request body: %s", err.Error()), "invalid_request_error")
"status": false,
"error": "invalid request body",
"reason": err.Error(),
})
return return
} }
@ -124,11 +145,7 @@ func TranshipmentAPI(c *gin.Context) {
check, plan := auth.CanEnableModelWithSubscription(db, cache, user, form.Model) check, plan := auth.CanEnableModelWithSubscription(db, cache, user, form.Model)
if !check { if !check {
c.JSON(http.StatusForbidden, gin.H{ sendErrorResponse(c, fmt.Errorf("quota exceeded"), "quota_exceeded_error")
"status": false,
"error": "quota exceeded",
"reason": "not enough quota to use this model",
})
return return
} }
@ -171,6 +188,9 @@ func sendTranshipmentResponse(c *gin.Context, form TranshipmentForm, id string,
if err != nil { if err != nil {
auth.RevertSubscriptionUsage(db, cache, user, form.Model) auth.RevertSubscriptionUsage(db, cache, user, form.Model)
globals.Warn(fmt.Sprintf("error from chat request api: %s (instance: %s, client: %s)", err, form.Model, c.ClientIP())) globals.Warn(fmt.Sprintf("error from chat request api: %s (instance: %s, client: %s)", err, form.Model, c.ClientIP()))
sendErrorResponse(c, err)
return
} }
CollectQuota(c, user, buffer, plan) CollectQuota(c, user, buffer, plan)
@ -195,7 +215,7 @@ func sendTranshipmentResponse(c *gin.Context, form TranshipmentForm, id string,
}) })
} }
func getStreamTranshipmentForm(id string, created int64, form TranshipmentForm, data string, buffer *utils.Buffer, end bool) TranshipmentStreamResponse { func getStreamTranshipmentForm(id string, created int64, form TranshipmentForm, data string, buffer *utils.Buffer, end bool, err error) TranshipmentStreamResponse {
return TranshipmentStreamResponse{ return TranshipmentStreamResponse{
Id: fmt.Sprintf("chatcmpl-%s", id), Id: fmt.Sprintf("chatcmpl-%s", id),
Object: "chat.completion.chunk", Object: "chat.completion.chunk",
@ -217,6 +237,7 @@ func getStreamTranshipmentForm(id string, created int64, form TranshipmentForm,
TotalTokens: utils.MultiF(end, func() int { return buffer.CountToken() }, 0), TotalTokens: utils.MultiF(end, func() int { return buffer.CountToken() }, 0),
}, },
Quota: utils.Multi[*float32](form.Official, nil, utils.ToPtr(buffer.GetQuota())), Quota: utils.Multi[*float32](form.Official, nil, utils.ToPtr(buffer.GetQuota())),
Error: err,
} }
} }
@ -228,20 +249,20 @@ func sendStreamTranshipmentResponse(c *gin.Context, form TranshipmentForm, id st
go func() { go func() {
buffer := utils.NewBuffer(form.Model, form.Messages) buffer := utils.NewBuffer(form.Model, form.Messages)
err := channel.NewChatRequest(GetProps(form, buffer, plan), func(data string) error { err := channel.NewChatRequest(GetProps(form, buffer, plan), func(data string) error {
partial <- getStreamTranshipmentForm(id, created, form, buffer.Write(data), buffer, false) partial <- getStreamTranshipmentForm(id, created, form, buffer.Write(data), buffer, false, nil)
return nil return nil
}) })
admin.AnalysisRequest(form.Model, buffer, err) admin.AnalysisRequest(form.Model, buffer, err)
if err != nil { if err != nil {
auth.RevertSubscriptionUsage(db, cache, user, form.Model) auth.RevertSubscriptionUsage(db, cache, user, form.Model)
partial <- getStreamTranshipmentForm(id, created, form, fmt.Sprintf("Error: %s", err.Error()), buffer, true) globals.Warn(fmt.Sprintf("error from chat request api: %s (instance: %s, client: %s)", err.Error(), form.Model, c.ClientIP()))
CollectQuota(c, user, buffer, plan) partial <- getStreamTranshipmentForm(id, created, form, err.Error(), buffer, true, err)
close(partial) close(partial)
return return
} }
partial <- getStreamTranshipmentForm(id, created, form, "", buffer, true) partial <- getStreamTranshipmentForm(id, created, form, "", buffer, true, nil)
CollectQuota(c, user, buffer, plan) CollectQuota(c, user, buffer, plan)
close(partial) close(partial)
return return
@ -249,6 +270,11 @@ func sendStreamTranshipmentResponse(c *gin.Context, form TranshipmentForm, id st
c.Stream(func(w io.Writer) bool { c.Stream(func(w io.Writer) bool {
if resp, ok := <-partial; ok { if resp, ok := <-partial; ok {
if resp.Error != nil {
sendErrorResponse(c, resp.Error)
return false
}
c.Render(-1, utils.NewEvent(resp)) c.Render(-1, utils.NewEvent(resp))
return true return true
} }

View File

@ -232,3 +232,10 @@ func GetError(err error) string {
} }
return "" return ""
} }
func GetIndexSafe[T any](arr []T, index int) *T {
if index >= len(arr) {
return nil
}
return &arr[index]
}