From 8e3a424e60004c010105a74c17e7b98e8f8fc15f Mon Sep 17 00:00:00 2001 From: Zhang Minghan Date: Fri, 1 Dec 2023 22:48:25 +0800 Subject: [PATCH] update: channel worker, channel sequence and ticker --- .gitignore | 1 + adapter/adapter.go | 154 +++++++++++++++++----------------- adapter/chatgpt/struct.go | 3 +- adapter/claude/struct.go | 3 +- adapter/dashscope/struct.go | 3 +- adapter/oneapi/globals.go | 26 ------ adapter/oneapi/struct.go | 5 -- adapter/palm2/struct.go | 3 +- adapter/request.go | 23 ++--- addition/generation/prompt.go | 3 +- channel/channel.go | 25 +++++- channel/manager.go | 153 +++++++++++++++++++++++++++++++++ channel/sequence.go | 29 +++++++ channel/ticker.go | 81 ++++++++++++++++++ channel/types.go | 36 ++++---- channel/worker.go | 29 +++++++ globals/constant.go | 28 +++---- globals/interface.go | 10 +++ main.go | 2 + manager/chat.go | 3 +- manager/completions.go | 3 +- manager/transhipment.go | 21 ++--- utils/base.go | 7 ++ utils/char.go | 7 +- utils/key.go | 12 --- 25 files changed, 478 insertions(+), 192 deletions(-) delete mode 100644 adapter/oneapi/globals.go create mode 100644 channel/manager.go create mode 100644 channel/sequence.go create mode 100644 channel/ticker.go create mode 100644 channel/worker.go create mode 100644 globals/interface.go delete mode 100644 utils/key.go diff --git a/.gitignore b/.gitignore index 206b211..6be5922 100644 --- a/.gitignore +++ b/.gitignore @@ -2,6 +2,7 @@ node_modules .vscode .idea config.yaml +config.dev.yaml addition/generation/data/* !addition/generation/data/.gitkeep diff --git a/adapter/adapter.go b/adapter/adapter.go index 5b4abc0..8ab54b9 100644 --- a/adapter/adapter.go +++ b/adapter/adapter.go @@ -17,11 +17,9 @@ import ( "chat/adapter/zhipuai" "chat/globals" "chat/utils" + "fmt" ) -var defaultMaxRetries = 3 -var midjourneyMaxRetries = 10 - type RequestProps struct { MaxRetries *int Current int @@ -46,36 +44,21 @@ type ChatProps struct { Buffer utils.Buffer } -func createChatRequest(props *ChatProps, hook globals.Hook) error { - if oneapi.IsHit(props.Model) { - return oneapi.NewChatInstanceFromConfig().CreateStreamChatRequest(&oneapi.ChatProps{ - Model: props.Model, - Message: props.Message, - Token: utils.Multi( - props.Token == 0, - utils.Multi(globals.IsGPT4Model(props.Model) || props.Plan || props.Infinity, nil, utils.ToPtr(2500)), - &props.Token, - ), - PresencePenalty: props.PresencePenalty, - FrequencyPenalty: props.FrequencyPenalty, - Temperature: props.Temperature, - TopP: props.TopP, - Tools: props.Tools, - ToolChoice: props.ToolChoice, - Buffer: props.Buffer, - }, hook) +func createChatRequest(conf globals.ChannelConfig, props *ChatProps, hook globals.Hook) error { + model := conf.GetModelReflect(props.Model) - } else if globals.IsChatGPTModel(props.Model) { + switch conf.GetType() { + case globals.OpenAIChannelType: instance := chatgpt.NewChatInstanceFromModel(&chatgpt.InstanceProps{ - Model: props.Model, + Model: model, Plan: props.Plan, }) return instance.CreateStreamChatRequest(&chatgpt.ChatProps{ - Model: props.Model, + Model: model, Message: props.Message, Token: utils.Multi( props.Token == 0, - utils.Multi(globals.IsGPT4Model(props.Model) || props.Plan || props.Infinity, nil, utils.ToPtr(2500)), + utils.Multi(globals.IsGPT4Model(model) || props.Plan || props.Infinity, nil, utils.ToPtr(2500)), &props.Token, ), PresencePenalty: props.PresencePenalty, @@ -87,9 +70,9 @@ func createChatRequest(props *ChatProps, hook globals.Hook) error { Buffer: props.Buffer, }, hook) - } else if globals.IsClaudeModel(props.Model) { + case globals.ClaudeChannelType: return claude.NewChatInstanceFromConfig().CreateStreamChatRequest(&claude.ChatProps{ - Model: props.Model, + Model: model, Message: props.Message, Token: utils.Multi(props.Token == 0, 50000, props.Token), TopP: props.TopP, @@ -97,9 +80,26 @@ func createChatRequest(props *ChatProps, hook globals.Hook) error { Temperature: props.Temperature, }, hook) - } else if globals.IsSparkDeskModel(props.Model) { - return sparkdesk.NewChatInstance(props.Model).CreateStreamChatRequest(&sparkdesk.ChatProps{ - Model: props.Model, + case globals.SlackChannelType: + return slack.NewChatInstanceFromConfig().CreateStreamChatRequest(&slack.ChatProps{ + Message: props.Message, + }, hook) + + case globals.BingChannelType: + return bing.NewChatInstanceFromConfig().CreateStreamChatRequest(&bing.ChatProps{ + Model: model, + Message: props.Message, + }, hook) + + case globals.PalmChannelType: + return palm2.NewChatInstanceFromConfig().CreateStreamChatRequest(&palm2.ChatProps{ + Model: model, + Message: props.Message, + }, hook) + + case globals.SparkdeskChannelType: + return sparkdesk.NewChatInstance(model).CreateStreamChatRequest(&sparkdesk.ChatProps{ + Model: model, Message: props.Message, Token: utils.Multi(props.Token == 0, nil, utils.ToPtr(props.Token)), Temperature: props.Temperature, @@ -108,34 +108,17 @@ func createChatRequest(props *ChatProps, hook globals.Hook) error { Buffer: props.Buffer, }, hook) - } else if globals.IsPalm2Model(props.Model) { - return palm2.NewChatInstanceFromConfig().CreateStreamChatRequest(&palm2.ChatProps{ - Model: props.Model, - Message: props.Message, - }, hook) - - } else if globals.IsSlackModel(props.Model) { - return slack.NewChatInstanceFromConfig().CreateStreamChatRequest(&slack.ChatProps{ - Message: props.Message, - }, hook) - - } else if globals.IsBingModel(props.Model) { - return bing.NewChatInstanceFromConfig().CreateStreamChatRequest(&bing.ChatProps{ - Model: props.Model, - Message: props.Message, - }, hook) - - } else if globals.IsZhiPuModel(props.Model) { + case globals.ChatGLMChannelType: return zhipuai.NewChatInstanceFromConfig().CreateStreamChatRequest(&zhipuai.ChatProps{ - Model: props.Model, + Model: model, Message: props.Message, Temperature: props.Temperature, TopP: props.TopP, }, hook) - } else if globals.IsQwenModel(props.Model) { + case globals.QwenChannelType: return dashscope.NewChatInstanceFromConfig().CreateStreamChatRequest(&dashscope.ChatProps{ - Model: props.Model, + Model: model, Message: props.Message, Token: utils.Multi(props.Infinity || props.Plan, 2048, props.Token), Temperature: props.Temperature, @@ -144,43 +127,26 @@ func createChatRequest(props *ChatProps, hook globals.Hook) error { RepetitionPenalty: props.RepetitionPenalty, }, hook) - } else if globals.IsMidjourneyModel(props.Model) { - return midjourney.NewChatInstanceFromConfig().CreateStreamChatRequest(&midjourney.ChatProps{ - Model: props.Model, - Messages: props.Message, - }, hook) - - } else if globals.IsHunyuanModel(props.Model) { + case globals.HunyuanChannelType: return hunyuan.NewChatInstanceFromConfig().CreateStreamChatRequest(&hunyuan.ChatProps{ - Model: props.Model, + Model: model, Message: props.Message, Temperature: props.Temperature, TopP: props.TopP, }, hook) - } else if globals.Is360Model(props.Model) { - return zhinao.NewChatInstanceFromConfig().CreateStreamChatRequest(&zhinao.ChatProps{ - Model: props.Model, - Message: props.Message, - Token: utils.Multi(props.Infinity || props.Plan, nil, utils.ToPtr(2048)), - TopP: props.TopP, - TopK: props.TopK, - Temperature: props.Temperature, - RepetitionPenalty: props.RepetitionPenalty, - }, hook) - - } else if globals.IsBaichuanModel(props.Model) { + case globals.BaichuanChannelType: return baichuan.NewChatInstanceFromConfig().CreateStreamChatRequest(&baichuan.ChatProps{ - Model: props.Model, + Model: model, Message: props.Message, TopP: props.TopP, TopK: props.TopK, Temperature: props.Temperature, }, hook) - } else if globals.IsSkylarkModel(props.Model) { + case globals.SkylarkChannelType: return skylark.NewChatInstanceFromConfig().CreateStreamChatRequest(&skylark.ChatProps{ - Model: props.Model, + Model: model, Message: props.Message, Token: utils.Multi(props.Token == 0, 4096, props.Token), TopP: props.TopP, @@ -191,7 +157,43 @@ func createChatRequest(props *ChatProps, hook globals.Hook) error { RepeatPenalty: props.RepetitionPenalty, Tools: props.Tools, }, hook) - } - return hook("Sorry, we cannot find the model you are looking for. Please try another model.") + case globals.ZhinaoChannelType: + return zhinao.NewChatInstanceFromConfig().CreateStreamChatRequest(&zhinao.ChatProps{ + Model: model, + Message: props.Message, + Token: utils.Multi(props.Infinity || props.Plan, nil, utils.ToPtr(2048)), + TopP: props.TopP, + TopK: props.TopK, + Temperature: props.Temperature, + RepetitionPenalty: props.RepetitionPenalty, + }, hook) + + case globals.MidjourneyChannelType: + return midjourney.NewChatInstanceFromConfig().CreateStreamChatRequest(&midjourney.ChatProps{ + Model: model, + Messages: props.Message, + }, hook) + + case globals.OneAPIChannelType: + return oneapi.NewChatInstanceFromConfig().CreateStreamChatRequest(&oneapi.ChatProps{ + Model: model, + Message: props.Message, + Token: utils.Multi( + props.Token == 0, + utils.Multi(globals.IsGPT4Model(model) || props.Plan || props.Infinity, nil, utils.ToPtr(2500)), + &props.Token, + ), + PresencePenalty: props.PresencePenalty, + FrequencyPenalty: props.FrequencyPenalty, + Temperature: props.Temperature, + TopP: props.TopP, + Tools: props.Tools, + ToolChoice: props.ToolChoice, + Buffer: props.Buffer, + }, hook) + + default: + return fmt.Errorf("unknown channel type %s for model %s", conf.GetType(), props.Model) + } } diff --git a/adapter/chatgpt/struct.go b/adapter/chatgpt/struct.go index feafe4a..45e55bd 100644 --- a/adapter/chatgpt/struct.go +++ b/adapter/chatgpt/struct.go @@ -2,7 +2,6 @@ package chatgpt import ( "chat/globals" - "chat/utils" "fmt" "github.com/spf13/viper" ) @@ -42,7 +41,7 @@ func NewChatInstance(endpoint, apiKey string) *ChatInstance { func NewChatInstanceFromConfig(v string) *ChatInstance { return NewChatInstance( viper.GetString(fmt.Sprintf("openai.%s.endpoint", v)), - utils.GetRandomKey(viper.GetString(fmt.Sprintf("openai.%s.apikey", v))), + viper.GetString(fmt.Sprintf("openai.%s.apikey", v)), ) } diff --git a/adapter/claude/struct.go b/adapter/claude/struct.go index 56a7ecb..b682abc 100644 --- a/adapter/claude/struct.go +++ b/adapter/claude/struct.go @@ -1,7 +1,6 @@ package claude import ( - "chat/utils" "github.com/spf13/viper" ) @@ -20,7 +19,7 @@ func NewChatInstance(endpoint, apiKey string) *ChatInstance { func NewChatInstanceFromConfig() *ChatInstance { return NewChatInstance( viper.GetString("claude.endpoint"), - utils.GetRandomKey(viper.GetString("claude.apikey")), + viper.GetString("claude.apikey"), ) } diff --git a/adapter/dashscope/struct.go b/adapter/dashscope/struct.go index 7dcff53..55130b1 100644 --- a/adapter/dashscope/struct.go +++ b/adapter/dashscope/struct.go @@ -1,7 +1,6 @@ package dashscope import ( - "chat/utils" "github.com/spf13/viper" ) @@ -28,6 +27,6 @@ func NewChatInstance(endpoint string, apiKey string) *ChatInstance { func NewChatInstanceFromConfig() *ChatInstance { return NewChatInstance( viper.GetString("dashscope.endpoint"), - utils.GetRandomKey(viper.GetString("dashscope.apikey")), + viper.GetString("dashscope.apikey"), ) } diff --git a/adapter/oneapi/globals.go b/adapter/oneapi/globals.go deleted file mode 100644 index f3d5531..0000000 --- a/adapter/oneapi/globals.go +++ /dev/null @@ -1,26 +0,0 @@ -package oneapi - -import ( - "chat/globals" -) - -var HitModels = []string{ - globals.Claude1, globals.Claude1100k, - globals.Claude2, globals.Claude2100k, - globals.StableDiffusion, - globals.LLaMa270B, globals.LLaMa213B, globals.LLaMa27B, - globals.CodeLLaMa34B, globals.CodeLLaMa13B, globals.CodeLLaMa7B, -} - -func (c *ChatInstance) GetToken(model string) int { - switch model { - case globals.Claude1, globals.Claude2: - return 5000 - case globals.Claude2100k, globals.Claude1100k: - return 50000 - case globals.LLaMa270B, globals.LLaMa213B, globals.LLaMa27B, globals.CodeLLaMa34B, globals.CodeLLaMa13B, globals.CodeLLaMa7B: - return 3000 - default: - return 2500 - } -} diff --git a/adapter/oneapi/struct.go b/adapter/oneapi/struct.go index 7038c51..a69d45a 100644 --- a/adapter/oneapi/struct.go +++ b/adapter/oneapi/struct.go @@ -1,7 +1,6 @@ package oneapi import ( - "chat/utils" "fmt" "github.com/spf13/viper" ) @@ -44,7 +43,3 @@ func NewChatInstanceFromConfig() *ChatInstance { viper.GetString("oneapi.apikey"), ) } - -func IsHit(model string) bool { - return utils.Contains[string](model, HitModels) -} diff --git a/adapter/palm2/struct.go b/adapter/palm2/struct.go index c458903..752d876 100644 --- a/adapter/palm2/struct.go +++ b/adapter/palm2/struct.go @@ -1,7 +1,6 @@ package palm2 import ( - "chat/utils" "github.com/spf13/viper" ) @@ -28,6 +27,6 @@ func NewChatInstance(endpoint string, apiKey string) *ChatInstance { func NewChatInstanceFromConfig() *ChatInstance { return NewChatInstance( viper.GetString("palm2.endpoint"), - utils.GetRandomKey(viper.GetString("palm2.apikey")), + viper.GetString("palm2.apikey"), ) } diff --git a/adapter/request.go b/adapter/request.go index e43531c..6f23e04 100644 --- a/adapter/request.go +++ b/adapter/request.go @@ -20,21 +20,10 @@ func isQPSOverLimit(model string, err error) bool { } } -func getRetries(model string, retries *int) int { - if retries == nil { - if globals.IsMidjourneyModel(model) { - return midjourneyMaxRetries - } - return defaultMaxRetries - } +func NewChatRequest(conf globals.ChannelConfig, props *ChatProps, hook globals.Hook) error { + err := createChatRequest(conf, props, hook) - return *retries -} - -func NewChatRequest(props *ChatProps, hook globals.Hook) error { - err := createChatRequest(props, hook) - - retries := getRetries(props.Model, props.MaxRetries) + retries := conf.GetRetry() props.Current++ if IsAvailableError(err) { @@ -43,14 +32,14 @@ func NewChatRequest(props *ChatProps, hook globals.Hook) error { fmt.Println(fmt.Sprintf("qps limit for %s, sleep and retry (times: %d)", props.Model, props.Current)) time.Sleep(500 * time.Millisecond) - return NewChatRequest(props, hook) + return NewChatRequest(conf, props, hook) } if props.Current < retries { fmt.Println(fmt.Sprintf("retrying chat request for %s (attempt %d/%d, error: %s)", props.Model, props.Current+1, retries, err.Error())) - return NewChatRequest(props, hook) + return NewChatRequest(conf, props, hook) } } - return err + return conf.ProcessError(err) } diff --git a/addition/generation/prompt.go b/addition/generation/prompt.go index 65a6cd9..472887e 100644 --- a/addition/generation/prompt.go +++ b/addition/generation/prompt.go @@ -3,6 +3,7 @@ package generation import ( "chat/adapter" "chat/admin" + "chat/channel" "chat/globals" "chat/utils" "fmt" @@ -16,7 +17,7 @@ func CreateGeneration(model string, prompt string, path string, plan bool, hook message := GenerateMessage(prompt) buffer := utils.NewBuffer(model, message) - err := adapter.NewChatRequest(&adapter.ChatProps{ + err := channel.NewChatRequest(&adapter.ChatProps{ Model: model, Message: message, Plan: plan, diff --git a/channel/channel.go b/channel/channel.go index 3b31010..90c9eb0 100644 --- a/channel/channel.go +++ b/channel/channel.go @@ -2,7 +2,8 @@ package channel import ( "chat/utils" - "math/rand" + "errors" + "fmt" "strings" ) @@ -46,9 +47,10 @@ func (c *Channel) GetSecret() string { return c.Secret } +// GetRandomSecret returns a random secret from the secret list func (c *Channel) GetRandomSecret() string { arr := strings.Split(c.GetSecret(), "\n") - idx := rand.Intn(len(arr)) + idx := utils.Intn(len(arr)) return arr[idx] } @@ -77,6 +79,7 @@ func (c *Channel) GetReflect() map[string]string { return *c.Reflect } +// GetModelReflect returns the reflection model name if it exists, otherwise returns the original model name func (c *Channel) GetModelReflect(model string) string { ref := c.GetReflect() if reflect, ok := ref[model]; ok && len(reflect) > 0 { @@ -114,3 +117,21 @@ func (c *Channel) GetHitModels() []string { func (c *Channel) GetState() bool { return c.State } + +func (c *Channel) IsHit(model string) bool { + return utils.Contains(model, c.GetHitModels()) +} + +func (c *Channel) ProcessError(err error) error { + if err == nil { + return nil + } + content := err.Error() + if strings.Contains(content, c.GetEndpoint()) { + // hide the endpoint + replacer := fmt.Sprintf("channel://%d", c.GetId()) + content = strings.Replace(content, c.GetEndpoint(), replacer, -1) + } + + return errors.New(content) +} diff --git a/channel/manager.go b/channel/manager.go new file mode 100644 index 0000000..d272645 --- /dev/null +++ b/channel/manager.go @@ -0,0 +1,153 @@ +package channel + +import ( + "chat/utils" + "github.com/spf13/viper" +) + +var ManagerInstance *Manager + +func InitManager() { + ManagerInstance = NewManager() +} + +func NewManager() *Manager { + var seq Sequence + if err := viper.UnmarshalKey("channel", &seq); err != nil { + panic(err) + } + + // sort by priority + + manager := &Manager{ + Sequence: seq, + } + + return manager +} + +func (m *Manager) Init() { + // init support models + for _, channel := range m.GetActiveSequence() { + for _, model := range channel.GetModels() { + if !utils.Contains(model, m.Models) { + m.Models = append(m.Models, model) + } + } + } + + // init preflight sequence + for _, model := range m.Models { + var seq Sequence + for _, channel := range m.GetActiveSequence() { + if utils.Contains(model, channel.GetModels()) { + seq = append(seq, channel) + } + } + seq.Sort() + m.PreflightSequence[model] = seq + } +} + +func (m *Manager) GetSequence() Sequence { + return m.Sequence +} + +func (m *Manager) GetActiveSequence() Sequence { + var seq Sequence + for _, channel := range m.Sequence { + if channel.GetState() { + seq = append(seq, channel) + } + } + seq.Sort() + return seq +} + +func (m *Manager) GetModels() []string { + return m.Models +} + +func (m *Manager) GetPreflightSequence() map[string]Sequence { + return m.PreflightSequence +} + +// HitSequence returns the preflight sequence of the model +func (m *Manager) HitSequence(model string) Sequence { + return m.PreflightSequence[model] +} + +// HasChannel returns whether the channel exists +func (m *Manager) HasChannel(model string) bool { + return utils.Contains(model, m.Models) +} + +func (m *Manager) GetTicker(model string) *Ticker { + return &Ticker{ + Sequence: m.HitSequence(model), + } +} + +func (m *Manager) Len() int { + return len(m.Sequence) +} + +func (m *Manager) GetMaxId() int { + var max int + for _, channel := range m.Sequence { + if channel.Id > max { + max = channel.Id + } + } + return max +} + +func (m *Manager) SaveConfig() error { + return viper.UnmarshalKey("channel", &m.Sequence) +} + +func (m *Manager) CreateChannel(channel *Channel) error { + channel.Id = m.GetMaxId() + 1 + m.Sequence = append(m.Sequence, channel) + return m.SaveConfig() +} + +func (m *Manager) UpdateChannel(id int, channel *Channel) error { + for i, item := range m.Sequence { + if item.Id == id { + m.Sequence[i] = channel + return m.SaveConfig() + } + } + return nil +} + +func (m *Manager) DeleteChannel(id int) error { + for i, item := range m.Sequence { + if item.Id == id { + m.Sequence = append(m.Sequence[:i], m.Sequence[i+1:]...) + return m.SaveConfig() + } + } + return nil +} + +func (m *Manager) ActivateChannel(id int) error { + for i, item := range m.Sequence { + if item.Id == id { + m.Sequence[i].State = true + return m.SaveConfig() + } + } + return nil +} + +func (m *Manager) DeactivateChannel(id int) error { + for i, item := range m.Sequence { + if item.Id == id { + m.Sequence[i].State = false + return m.SaveConfig() + } + } + return nil +} diff --git a/channel/sequence.go b/channel/sequence.go new file mode 100644 index 0000000..81e5be4 --- /dev/null +++ b/channel/sequence.go @@ -0,0 +1,29 @@ +package channel + +import "sort" + +func (s *Sequence) Len() int { + return len(*s) +} + +func (s *Sequence) Less(i, j int) bool { + return (*s)[i].GetPriority() < (*s)[j].GetPriority() +} + +func (s *Sequence) Swap(i, j int) { + (*s)[i], (*s)[j] = (*s)[j], (*s)[i] +} + +func (s *Sequence) GetChannelById(id int) *Channel { + for _, channel := range *s { + if channel.Id == id { + return channel + } + } + return nil +} + +func (s *Sequence) Sort() { + // sort by priority + sort.Sort(s) +} diff --git a/channel/ticker.go b/channel/ticker.go new file mode 100644 index 0000000..436ad68 --- /dev/null +++ b/channel/ticker.go @@ -0,0 +1,81 @@ +package channel + +import "chat/utils" + +func (t *Ticker) GetChannelByPriority(priority int) *Channel { + var stack Sequence + + for idx, channel := range t.Sequence { + if channel.GetPriority() == priority { + // get if the next channel has the same priority + if idx+1 < len(t.Sequence) && t.Sequence[idx+1].GetPriority() == priority { + stack = append(stack, channel) + continue + } + + if len(stack) == 0 { + return channel + } + + // stack is not empty + stack = append(stack, channel) + + // sort by weight and break the loop + if idx+1 >= len(t.Sequence) || t.Sequence[idx+1].GetPriority() != priority { + stack.Sort() + break + } + } + } + + weight := utils.Each(stack, func(channel *Channel) int { + return channel.GetWeight() + }) + total := utils.Sum(weight) + + // get random number + cursor := utils.Intn(total) + + // get channel by weight + for _, channel := range stack { + cursor -= channel.GetWeight() + if cursor < 0 { + return channel + } + } + + return stack[0] +} + +func (t *Ticker) Next() *Channel { + if t.Cursor >= len(t.Sequence) { + // out of sequence + return nil + } + + priority := t.Sequence[t.Cursor].GetPriority() + channel := t.GetChannelByPriority(priority) + t.SkipPriority(priority) + + return channel +} + +func (t *Ticker) SkipPriority(priority int) { + for idx, channel := range t.Sequence { + if channel.GetPriority() == priority { + // get if the next channel does not have the same priority or out of sequence + if idx+1 >= len(t.Sequence) || t.Sequence[idx+1].GetPriority() != priority { + t.Cursor = idx + 1 + break + } + } + } +} + +func (t *Ticker) Skip() { + t.Cursor++ +} + +func (t *Ticker) IsDone() bool { + return t.Cursor >= len(t.Sequence) +} diff --git a/channel/types.go b/channel/types.go index 77610ed..61763c0 100644 --- a/channel/types.go +++ b/channel/types.go @@ -1,24 +1,30 @@ package channel type Channel struct { - Id int `json:"id"` - Name string `json:"name"` - Type string `json:"type"` - Priority int `json:"priority"` - Weight int `json:"weight"` - Models []string `json:"models"` - Retry int `json:"retry"` - Secret string `json:"secret"` - Endpoint string `json:"endpoint"` - Mapper string `json:"mapper"` - State bool `json:"state"` - - Reflect *map[string]string `json:"reflect"` - HitModels *[]string `json:"hit_models"` + Id int `json:"id" mapstructure:"id"` + Name string `json:"name" mapstructure:"name"` + Type string `json:"type" mapstructure:"type"` + Priority int `json:"priority" mapstructure:"priority"` + Weight int `json:"weight" mapstructure:"weight"` + Models []string `json:"models" mapstructure:"models"` + Retry int `json:"retry" mapstructure:"retry"` + Secret string `json:"secret" mapstructure:"secret"` + Endpoint string `json:"endpoint" mapstructure:"endpoint"` + Mapper string `json:"mapper" mapstructure:"mapper"` + State bool `json:"state" mapstructure:"state"` + Reflect *map[string]string `json:"reflect" mapstructure:"reflect"` + HitModels *[]string `json:"hit_models" mapstructure:"hit_models"` } type Sequence []*Channel type Manager struct { - Sequence Sequence `json:"sequence"` + Sequence Sequence `json:"sequence"` + PreflightSequence map[string]Sequence `json:"preflight_sequence"` + Models []string `json:"models"` +} + +type Ticker struct { + Sequence Sequence `json:"sequence"` + Cursor int `json:"cursor"` } diff --git a/channel/worker.go b/channel/worker.go new file mode 100644 index 0000000..df1131a --- /dev/null +++ b/channel/worker.go @@ -0,0 +1,29 @@ +package channel + +import ( + "chat/adapter" + "chat/globals" + "chat/utils" + "fmt" +) + +func NewChatRequest(props *adapter.ChatProps, hook globals.Hook) error { + if !ManagerInstance.HasChannel(props.Model) { + return fmt.Errorf("cannot find channel for model %s", props.Model) + } + + ticker := ManagerInstance.GetTicker(props.Model) + + var err error + for !ticker.IsDone() { + if channel := ticker.Next(); channel != nil { + props.MaxRetries = utils.ToPtr(channel.GetRetry()) + if err = adapter.NewChatRequest(channel, props, hook); err == nil { + return nil + } + fmt.Println(fmt.Sprintf("[channel] hit error %s for model %s, goto next channel", err.Error(), props.Model)) + } + } + + return err +} diff --git a/globals/constant.go b/globals/constant.go index 1d1dab2..35434c1 100644 --- a/globals/constant.go +++ b/globals/constant.go @@ -8,18 +8,18 @@ const ( ) const ( - OpenAIChannelType = iota - ClaudeChannelType - SlackChannelType - SparkdeskChannelType - ChatGLMChannelType - DashscopeChannelType - HunyuanChannelType - ZhinaoChannelType - BaichuanChannelType - SkylarkChannelType - BingChannelType - PalmChannelType - MidjourneyChannelType - OneAPIChannelType + OpenAIChannelType = "openai" + ClaudeChannelType = "claude" + SlackChannelType = "slack" + SparkdeskChannelType = "sparkdesk" + ChatGLMChannelType = "chatglm" + QwenChannelType = "qwen" + HunyuanChannelType = "hunyuan" + ZhinaoChannelType = "zhinao" + BaichuanChannelType = "baichuan" + SkylarkChannelType = "skylark" + BingChannelType = "bing" + PalmChannelType = "palm" + MidjourneyChannelType = "midjourney" + OneAPIChannelType = "oneapi" ) diff --git a/globals/interface.go b/globals/interface.go new file mode 100644 index 0000000..2c3ab3b --- /dev/null +++ b/globals/interface.go @@ -0,0 +1,10 @@ +package globals + +type ChannelConfig interface { + GetType() string + GetModelReflect(model string) string + GetRetry() int + GetRandomSecret() string + GetEndpoint() string + ProcessError(err error) error +} diff --git a/main.go b/main.go index 75f6856..afd077f 100644 --- a/main.go +++ b/main.go @@ -5,6 +5,7 @@ import ( "chat/addition" "chat/admin" "chat/auth" + "chat/channel" "chat/cli" "chat/manager" "chat/manager/conversation" @@ -23,6 +24,7 @@ func main() { if cli.Run() { return } + channel.InitManager() app := gin.Default() middleware.RegisterMiddleware(app) diff --git a/manager/chat.go b/manager/chat.go index 563f162..3978527 100644 --- a/manager/chat.go +++ b/manager/chat.go @@ -5,6 +5,7 @@ import ( "chat/addition/web" "chat/admin" "chat/auth" + "chat/channel" "chat/globals" "chat/manager/conversation" "chat/utils" @@ -86,7 +87,7 @@ func ChatHandler(conn *Connection, user *auth.User, instance *conversation.Conve } buffer := utils.NewBuffer(model, segment) - err := adapter.NewChatRequest(&adapter.ChatProps{ + err := channel.NewChatRequest(&adapter.ChatProps{ Model: model, Message: segment, Plan: plan, diff --git a/manager/completions.go b/manager/completions.go index f86aefa..c55dd7b 100644 --- a/manager/completions.go +++ b/manager/completions.go @@ -5,6 +5,7 @@ import ( "chat/addition/web" "chat/admin" "chat/auth" + "chat/channel" "chat/globals" "chat/utils" "fmt" @@ -39,7 +40,7 @@ func NativeChatHandler(c *gin.Context, user *auth.User, model string, message [] } buffer := utils.NewBuffer(model, segment) - err := adapter.NewChatRequest(&adapter.ChatProps{ + err := channel.NewChatRequest(&adapter.ChatProps{ Model: model, Plan: plan, Message: segment, diff --git a/manager/transhipment.go b/manager/transhipment.go index 41f20a9..45fb444 100644 --- a/manager/transhipment.go +++ b/manager/transhipment.go @@ -5,6 +5,7 @@ import ( "chat/addition/web" "chat/admin" "chat/auth" + "chat/channel" "chat/globals" "chat/utils" "fmt" @@ -70,7 +71,7 @@ type TranshipmentStreamResponse struct { } func ModelAPI(c *gin.Context) { - c.JSON(http.StatusOK, globals.AllModels) + c.JSON(http.StatusOK, channel.ManagerInstance.GetModels()) } func TranshipmentAPI(c *gin.Context) { @@ -162,7 +163,7 @@ func sendTranshipmentResponse(c *gin.Context, form TranshipmentForm, id string, cache := utils.GetCacheFromContext(c) buffer := utils.NewBuffer(form.Model, form.Messages) - err := adapter.NewChatRequest(GetProps(form, buffer, plan), func(data string) error { + err := channel.NewChatRequest(GetProps(form, buffer, plan), func(data string) error { buffer.Write(data) return nil }) @@ -221,34 +222,34 @@ func getStreamTranshipmentForm(id string, created int64, form TranshipmentForm, } func sendStreamTranshipmentResponse(c *gin.Context, form TranshipmentForm, id string, created int64, user *auth.User, plan bool) { - channel := make(chan TranshipmentStreamResponse) + partial := make(chan TranshipmentStreamResponse) db := utils.GetDBFromContext(c) cache := utils.GetCacheFromContext(c) go func() { buffer := utils.NewBuffer(form.Model, form.Messages) - err := adapter.NewChatRequest(GetProps(form, buffer, plan), func(data string) error { - channel <- getStreamTranshipmentForm(id, created, form, buffer.Write(data), buffer, false) + err := channel.NewChatRequest(GetProps(form, buffer, plan), func(data string) error { + partial <- getStreamTranshipmentForm(id, created, form, buffer.Write(data), buffer, false) return nil }) admin.AnalysisRequest(form.Model, buffer, err) if err != nil { auth.RevertSubscriptionUsage(db, cache, user, form.Model) - channel <- getStreamTranshipmentForm(id, created, form, fmt.Sprintf("Error: %s", err.Error()), buffer, true) + partial <- getStreamTranshipmentForm(id, created, form, fmt.Sprintf("Error: %s", err.Error()), buffer, true) CollectQuota(c, user, buffer, plan) - close(channel) + close(partial) return } - channel <- getStreamTranshipmentForm(id, created, form, "", buffer, true) + partial <- getStreamTranshipmentForm(id, created, form, "", buffer, true) CollectQuota(c, user, buffer, plan) - close(channel) + close(partial) return }() c.Stream(func(w io.Writer) bool { - if resp, ok := <-channel; ok { + if resp, ok := <-partial; ok { c.Render(-1, utils.NewEvent(resp)) return true } diff --git a/utils/base.go b/utils/base.go index 96d125a..3029fae 100644 --- a/utils/base.go +++ b/utils/base.go @@ -3,9 +3,16 @@ package utils import ( "fmt" "github.com/goccy/go-json" + "math/rand" "time" ) +func Intn(n int) int { + source := rand.NewSource(time.Now().UnixNano()) + r := rand.New(source) + return r.Intn(n) +} + func Sum[T int | int64 | float32 | float64](arr []T) T { var res T for _, v := range arr { diff --git a/utils/char.go b/utils/char.go index 42c00e9..6b6417b 100644 --- a/utils/char.go +++ b/utils/char.go @@ -3,7 +3,6 @@ package utils import ( "fmt" "github.com/goccy/go-json" - "math/rand" "regexp" "strconv" "strings" @@ -11,13 +10,13 @@ import ( ) func GetRandomInt(min int, max int) int { - return rand.Intn(max-min) + min + return Intn(max-min) + min } func GenerateCode(length int) string { var code string for i := 0; i < length; i++ { - code += strconv.Itoa(rand.Intn(10)) + code += strconv.Itoa(Intn(10)) } return code } @@ -26,7 +25,7 @@ func GenerateChar(length int) string { const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" result := make([]byte, length) for i := 0; i < length; i++ { - result[i] = charset[rand.Intn(len(charset))] + result[i] = charset[Intn(len(charset))] } return string(result) } diff --git a/utils/key.go b/utils/key.go deleted file mode 100644 index fdece80..0000000 --- a/utils/key.go +++ /dev/null @@ -1,12 +0,0 @@ -package utils - -import ( - "math/rand" - "strings" -) - -func GetRandomKey(apikey string) string { - arr := strings.Split(apikey, "|") - idx := rand.Intn(len(arr)) - return arr[idx] -}