mirror of https://github.com/coaidev/coai.git (synced 2025-05-19 21:10:18 +09:00)

commit 8e3a424e60 (parent db7acee643)
update: channel worker, channel sequence and ticker

.gitignore (vendored, 1 line changed)
@@ -2,6 +2,7 @@ node_modules
.vscode
.idea
config.yaml
config.dev.yaml

addition/generation/data/*
!addition/generation/data/.gitkeep
@@ -17,11 +17,9 @@ import (
"chat/adapter/zhipuai"
"chat/globals"
"chat/utils"
"fmt"
)

var defaultMaxRetries = 3
var midjourneyMaxRetries = 10

type RequestProps struct {
MaxRetries *int
Current int
@@ -46,36 +44,21 @@ type ChatProps struct {
Buffer utils.Buffer
}

func createChatRequest(props *ChatProps, hook globals.Hook) error {
if oneapi.IsHit(props.Model) {
return oneapi.NewChatInstanceFromConfig().CreateStreamChatRequest(&oneapi.ChatProps{
Model: props.Model,
Message: props.Message,
Token: utils.Multi(
props.Token == 0,
utils.Multi(globals.IsGPT4Model(props.Model) || props.Plan || props.Infinity, nil, utils.ToPtr(2500)),
&props.Token,
),
PresencePenalty: props.PresencePenalty,
FrequencyPenalty: props.FrequencyPenalty,
Temperature: props.Temperature,
TopP: props.TopP,
Tools: props.Tools,
ToolChoice: props.ToolChoice,
Buffer: props.Buffer,
}, hook)
func createChatRequest(conf globals.ChannelConfig, props *ChatProps, hook globals.Hook) error {
model := conf.GetModelReflect(props.Model)

} else if globals.IsChatGPTModel(props.Model) {
switch conf.GetType() {
case globals.OpenAIChannelType:
instance := chatgpt.NewChatInstanceFromModel(&chatgpt.InstanceProps{
Model: props.Model,
Model: model,
Plan: props.Plan,
})
return instance.CreateStreamChatRequest(&chatgpt.ChatProps{
Model: props.Model,
Model: model,
Message: props.Message,
Token: utils.Multi(
props.Token == 0,
utils.Multi(globals.IsGPT4Model(props.Model) || props.Plan || props.Infinity, nil, utils.ToPtr(2500)),
utils.Multi(globals.IsGPT4Model(model) || props.Plan || props.Infinity, nil, utils.ToPtr(2500)),
&props.Token,
),
PresencePenalty: props.PresencePenalty,
@@ -87,9 +70,9 @@ func createChatRequest(props *ChatProps, hook globals.Hook) error {
Buffer: props.Buffer,
}, hook)

} else if globals.IsClaudeModel(props.Model) {
case globals.ClaudeChannelType:
return claude.NewChatInstanceFromConfig().CreateStreamChatRequest(&claude.ChatProps{
Model: props.Model,
Model: model,
Message: props.Message,
Token: utils.Multi(props.Token == 0, 50000, props.Token),
TopP: props.TopP,
@@ -97,9 +80,26 @@ func createChatRequest(props *ChatProps, hook globals.Hook) error {
Temperature: props.Temperature,
}, hook)

} else if globals.IsSparkDeskModel(props.Model) {
return sparkdesk.NewChatInstance(props.Model).CreateStreamChatRequest(&sparkdesk.ChatProps{
Model: props.Model,
case globals.SlackChannelType:
return slack.NewChatInstanceFromConfig().CreateStreamChatRequest(&slack.ChatProps{
Message: props.Message,
}, hook)

case globals.BingChannelType:
return bing.NewChatInstanceFromConfig().CreateStreamChatRequest(&bing.ChatProps{
Model: model,
Message: props.Message,
}, hook)

case globals.PalmChannelType:
return palm2.NewChatInstanceFromConfig().CreateStreamChatRequest(&palm2.ChatProps{
Model: model,
Message: props.Message,
}, hook)

case globals.SparkdeskChannelType:
return sparkdesk.NewChatInstance(model).CreateStreamChatRequest(&sparkdesk.ChatProps{
Model: model,
Message: props.Message,
Token: utils.Multi(props.Token == 0, nil, utils.ToPtr(props.Token)),
Temperature: props.Temperature,
@@ -108,34 +108,17 @@ func createChatRequest(props *ChatProps, hook globals.Hook) error {
Buffer: props.Buffer,
}, hook)

} else if globals.IsPalm2Model(props.Model) {
return palm2.NewChatInstanceFromConfig().CreateStreamChatRequest(&palm2.ChatProps{
Model: props.Model,
Message: props.Message,
}, hook)

} else if globals.IsSlackModel(props.Model) {
return slack.NewChatInstanceFromConfig().CreateStreamChatRequest(&slack.ChatProps{
Message: props.Message,
}, hook)

} else if globals.IsBingModel(props.Model) {
return bing.NewChatInstanceFromConfig().CreateStreamChatRequest(&bing.ChatProps{
Model: props.Model,
Message: props.Message,
}, hook)

} else if globals.IsZhiPuModel(props.Model) {
case globals.ChatGLMChannelType:
return zhipuai.NewChatInstanceFromConfig().CreateStreamChatRequest(&zhipuai.ChatProps{
Model: props.Model,
Model: model,
Message: props.Message,
Temperature: props.Temperature,
TopP: props.TopP,
}, hook)

} else if globals.IsQwenModel(props.Model) {
case globals.QwenChannelType:
return dashscope.NewChatInstanceFromConfig().CreateStreamChatRequest(&dashscope.ChatProps{
Model: props.Model,
Model: model,
Message: props.Message,
Token: utils.Multi(props.Infinity || props.Plan, 2048, props.Token),
Temperature: props.Temperature,
@@ -144,43 +127,26 @@ func createChatRequest(props *ChatProps, hook globals.Hook) error {
RepetitionPenalty: props.RepetitionPenalty,
}, hook)

} else if globals.IsMidjourneyModel(props.Model) {
return midjourney.NewChatInstanceFromConfig().CreateStreamChatRequest(&midjourney.ChatProps{
Model: props.Model,
Messages: props.Message,
}, hook)

} else if globals.IsHunyuanModel(props.Model) {
case globals.HunyuanChannelType:
return hunyuan.NewChatInstanceFromConfig().CreateStreamChatRequest(&hunyuan.ChatProps{
Model: props.Model,
Model: model,
Message: props.Message,
Temperature: props.Temperature,
TopP: props.TopP,
}, hook)

} else if globals.Is360Model(props.Model) {
return zhinao.NewChatInstanceFromConfig().CreateStreamChatRequest(&zhinao.ChatProps{
Model: props.Model,
Message: props.Message,
Token: utils.Multi(props.Infinity || props.Plan, nil, utils.ToPtr(2048)),
TopP: props.TopP,
TopK: props.TopK,
Temperature: props.Temperature,
RepetitionPenalty: props.RepetitionPenalty,
}, hook)

} else if globals.IsBaichuanModel(props.Model) {
case globals.BaichuanChannelType:
return baichuan.NewChatInstanceFromConfig().CreateStreamChatRequest(&baichuan.ChatProps{
Model: props.Model,
Model: model,
Message: props.Message,
TopP: props.TopP,
TopK: props.TopK,
Temperature: props.Temperature,
}, hook)

} else if globals.IsSkylarkModel(props.Model) {
case globals.SkylarkChannelType:
return skylark.NewChatInstanceFromConfig().CreateStreamChatRequest(&skylark.ChatProps{
Model: props.Model,
Model: model,
Message: props.Message,
Token: utils.Multi(props.Token == 0, 4096, props.Token),
TopP: props.TopP,
@@ -191,7 +157,43 @@ func createChatRequest(props *ChatProps, hook globals.Hook) error {
RepeatPenalty: props.RepetitionPenalty,
Tools: props.Tools,
}, hook)
}

return hook("Sorry, we cannot find the model you are looking for. Please try another model.")
case globals.ZhinaoChannelType:
return zhinao.NewChatInstanceFromConfig().CreateStreamChatRequest(&zhinao.ChatProps{
Model: model,
Message: props.Message,
Token: utils.Multi(props.Infinity || props.Plan, nil, utils.ToPtr(2048)),
TopP: props.TopP,
TopK: props.TopK,
Temperature: props.Temperature,
RepetitionPenalty: props.RepetitionPenalty,
}, hook)

case globals.MidjourneyChannelType:
return midjourney.NewChatInstanceFromConfig().CreateStreamChatRequest(&midjourney.ChatProps{
Model: model,
Messages: props.Message,
}, hook)

case globals.OneAPIChannelType:
return oneapi.NewChatInstanceFromConfig().CreateStreamChatRequest(&oneapi.ChatProps{
Model: model,
Message: props.Message,
Token: utils.Multi(
props.Token == 0,
utils.Multi(globals.IsGPT4Model(model) || props.Plan || props.Infinity, nil, utils.ToPtr(2500)),
&props.Token,
),
PresencePenalty: props.PresencePenalty,
FrequencyPenalty: props.FrequencyPenalty,
Temperature: props.Temperature,
TopP: props.TopP,
Tools: props.Tools,
ToolChoice: props.ToolChoice,
Buffer: props.Buffer,
}, hook)

default:
return fmt.Errorf("unknown channel type %s for model %s", conf.GetType(), props.Model)
}
}
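The rewritten createChatRequest above no longer sniffs model names through globals.IsXModel helpers: the selected channel reports its type and an optional model mapping (conf.GetModelReflect), and a single switch on conf.GetType() routes the request. A standalone toy sketch of that pattern, for illustration only — channelConfig, stubChannel and the type strings below are stand-ins, not the repository's API:

    package main

    import (
        "errors"
        "fmt"
    )

    // minimal stand-in for the globals.ChannelConfig idea
    type channelConfig interface {
        GetType() string
        GetModelReflect(model string) string
    }

    type stubChannel struct {
        kind    string
        mapping map[string]string
    }

    func (c *stubChannel) GetType() string { return c.kind }

    // return the mapped upstream model name if one is configured
    func (c *stubChannel) GetModelReflect(model string) string {
        if target, ok := c.mapping[model]; ok && len(target) > 0 {
            return target
        }
        return model
    }

    // dispatch on the channel type instead of on the model name
    func createRequest(conf channelConfig, model string) error {
        target := conf.GetModelReflect(model)
        switch conf.GetType() {
        case "openai":
            fmt.Println("stream from an openai-style upstream, model:", target)
            return nil
        case "claude":
            fmt.Println("stream from a claude upstream, model:", target)
            return nil
        default:
            return errors.New("unknown channel type " + conf.GetType())
        }
    }

    func main() {
        conf := &stubChannel{kind: "openai", mapping: map[string]string{"gpt-4": "gpt-4-1106-preview"}}
        _ = createRequest(conf, "gpt-4")
    }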
@@ -2,7 +2,6 @@ package chatgpt

import (
"chat/globals"
"chat/utils"
"fmt"
"github.com/spf13/viper"
)
@@ -42,7 +41,7 @@ func NewChatInstance(endpoint, apiKey string) *ChatInstance {
func NewChatInstanceFromConfig(v string) *ChatInstance {
return NewChatInstance(
viper.GetString(fmt.Sprintf("openai.%s.endpoint", v)),
utils.GetRandomKey(viper.GetString(fmt.Sprintf("openai.%s.apikey", v))),
viper.GetString(fmt.Sprintf("openai.%s.apikey", v)),
)
}
@@ -1,7 +1,6 @@
package claude

import (
"chat/utils"
"github.com/spf13/viper"
)

@@ -20,7 +19,7 @@ func NewChatInstance(endpoint, apiKey string) *ChatInstance {
func NewChatInstanceFromConfig() *ChatInstance {
return NewChatInstance(
viper.GetString("claude.endpoint"),
utils.GetRandomKey(viper.GetString("claude.apikey")),
viper.GetString("claude.apikey"),
)
}
@@ -1,7 +1,6 @@
package dashscope

import (
"chat/utils"
"github.com/spf13/viper"
)

@@ -28,6 +27,6 @@ func NewChatInstance(endpoint string, apiKey string) *ChatInstance {
func NewChatInstanceFromConfig() *ChatInstance {
return NewChatInstance(
viper.GetString("dashscope.endpoint"),
utils.GetRandomKey(viper.GetString("dashscope.apikey")),
viper.GetString("dashscope.apikey"),
)
}
@@ -1,26 +0,0 @@
package oneapi

import (
    "chat/globals"
)

var HitModels = []string{
    globals.Claude1, globals.Claude1100k,
    globals.Claude2, globals.Claude2100k,
    globals.StableDiffusion,
    globals.LLaMa270B, globals.LLaMa213B, globals.LLaMa27B,
    globals.CodeLLaMa34B, globals.CodeLLaMa13B, globals.CodeLLaMa7B,
}

func (c *ChatInstance) GetToken(model string) int {
    switch model {
    case globals.Claude1, globals.Claude2:
        return 5000
    case globals.Claude2100k, globals.Claude1100k:
        return 50000
    case globals.LLaMa270B, globals.LLaMa213B, globals.LLaMa27B, globals.CodeLLaMa34B, globals.CodeLLaMa13B, globals.CodeLLaMa7B:
        return 3000
    default:
        return 2500
    }
}
@@ -1,7 +1,6 @@
package oneapi

import (
"chat/utils"
"fmt"
"github.com/spf13/viper"
)
@@ -44,7 +43,3 @@ func NewChatInstanceFromConfig() *ChatInstance {
viper.GetString("oneapi.apikey"),
)
}

func IsHit(model string) bool {
return utils.Contains[string](model, HitModels)
}
@@ -1,7 +1,6 @@
package palm2

import (
"chat/utils"
"github.com/spf13/viper"
)

@@ -28,6 +27,6 @@ func NewChatInstance(endpoint string, apiKey string) *ChatInstance {
func NewChatInstanceFromConfig() *ChatInstance {
return NewChatInstance(
viper.GetString("palm2.endpoint"),
utils.GetRandomKey(viper.GetString("palm2.apikey")),
viper.GetString("palm2.apikey"),
)
}
@@ -20,21 +20,10 @@ func isQPSOverLimit(model string, err error) bool {
}
}

func getRetries(model string, retries *int) int {
if retries == nil {
if globals.IsMidjourneyModel(model) {
return midjourneyMaxRetries
}
return defaultMaxRetries
}
func NewChatRequest(conf globals.ChannelConfig, props *ChatProps, hook globals.Hook) error {
err := createChatRequest(conf, props, hook)

return *retries
}

func NewChatRequest(props *ChatProps, hook globals.Hook) error {
err := createChatRequest(props, hook)

retries := getRetries(props.Model, props.MaxRetries)
retries := conf.GetRetry()
props.Current++

if IsAvailableError(err) {
@@ -43,14 +32,14 @@ func NewChatRequest(props *ChatProps, hook globals.Hook) error {

fmt.Println(fmt.Sprintf("qps limit for %s, sleep and retry (times: %d)", props.Model, props.Current))
time.Sleep(500 * time.Millisecond)
return NewChatRequest(props, hook)
return NewChatRequest(conf, props, hook)
}

if props.Current < retries {
fmt.Println(fmt.Sprintf("retrying chat request for %s (attempt %d/%d, error: %s)", props.Model, props.Current+1, retries, err.Error()))
return NewChatRequest(props, hook)
return NewChatRequest(conf, props, hook)
}
}

return err
return conf.ProcessError(err)
}
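NewChatRequest now takes its retry budget from the channel itself (conf.GetRetry()) rather than the removed per-model constants, and the final error is filtered through conf.ProcessError before being returned. A simplified standalone sketch of the retry shape — the real code recurses, tracks the attempt count in props.Current, and only sleeps for QPS errors:

    package main

    import (
        "errors"
        "fmt"
        "time"
    )

    // withRetry retries a call up to the channel-configured number of attempts.
    func withRetry(retries int, call func() error) error {
        var err error
        for attempt := 1; attempt <= retries; attempt++ {
            if err = call(); err == nil {
                return nil
            }
            fmt.Printf("attempt %d/%d failed: %s\n", attempt, retries, err)
            time.Sleep(500 * time.Millisecond)
        }
        return err
    }

    func main() {
        err := withRetry(3, func() error { return errors.New("qps limit") })
        fmt.Println("final:", err)
    }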
@@ -3,6 +3,7 @@ package generation
import (
"chat/adapter"
"chat/admin"
"chat/channel"
"chat/globals"
"chat/utils"
"fmt"
@@ -16,7 +17,7 @@ func CreateGeneration(model string, prompt string, path string, plan bool, hook
message := GenerateMessage(prompt)
buffer := utils.NewBuffer(model, message)

err := adapter.NewChatRequest(&adapter.ChatProps{
err := channel.NewChatRequest(&adapter.ChatProps{
Model: model,
Message: message,
Plan: plan,
@@ -2,7 +2,8 @@ package channel

import (
"chat/utils"
"math/rand"
"errors"
"fmt"
"strings"
)

@@ -46,9 +47,10 @@ func (c *Channel) GetSecret() string {
return c.Secret
}

// GetRandomSecret returns a random secret from the secret list
func (c *Channel) GetRandomSecret() string {
arr := strings.Split(c.GetSecret(), "\n")
idx := rand.Intn(len(arr))
idx := utils.Intn(len(arr))
return arr[idx]
}

@@ -77,6 +79,7 @@ func (c *Channel) GetReflect() map[string]string {
return *c.Reflect
}

// GetModelReflect returns the reflection model name if it exists, otherwise returns the original model name
func (c *Channel) GetModelReflect(model string) string {
ref := c.GetReflect()
if reflect, ok := ref[model]; ok && len(reflect) > 0 {
@@ -114,3 +117,21 @@ func (c *Channel) GetHitModels() []string {
func (c *Channel) GetState() bool {
return c.State
}

func (c *Channel) IsHit(model string) bool {
return utils.Contains(model, c.GetHitModels())
}

func (c *Channel) ProcessError(err error) error {
if err == nil {
return nil
}
content := err.Error()
if strings.Contains(content, c.GetEndpoint()) {
// hide the endpoint
replacer := fmt.Sprintf("channel://%d", c.GetId())
content = strings.Replace(content, c.GetEndpoint(), replacer, -1)
}

return errors.New(content)
}
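ProcessError is what NewChatRequest now returns errors through: if the upstream error message contains the channel's endpoint, that endpoint is replaced with an opaque channel://<id> marker so it is not leaked to callers. A small standalone illustration (the endpoint and id below are made up):

    package main

    import (
        "fmt"
        "strings"
    )

    func main() {
        endpoint := "https://api.upstream.example"
        raw := "Post " + endpoint + "/v1/chat/completions: context deadline exceeded"

        // same masking as Channel.ProcessError, with id 7 standing in for c.GetId()
        masked := strings.Replace(raw, endpoint, fmt.Sprintf("channel://%d", 7), -1)
        fmt.Println(masked)
        // Output: Post channel://7/v1/chat/completions: context deadline exceeded
    }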
channel/manager.go (new file, 153 lines)
@@ -0,0 +1,153 @@
package channel

import (
    "chat/utils"
    "github.com/spf13/viper"
)

var ManagerInstance *Manager

func InitManager() {
    ManagerInstance = NewManager()
}

func NewManager() *Manager {
    var seq Sequence
    if err := viper.UnmarshalKey("channel", &seq); err != nil {
        panic(err)
    }

    // sort by priority

    manager := &Manager{
        Sequence: seq,
    }

    return manager
}

func (m *Manager) Init() {
    // init support models
    for _, channel := range m.GetActiveSequence() {
        for _, model := range channel.GetModels() {
            if !utils.Contains(model, m.Models) {
                m.Models = append(m.Models, model)
            }
        }
    }

    // init preflight sequence
    for _, model := range m.Models {
        var seq Sequence
        for _, channel := range m.GetActiveSequence() {
            if utils.Contains(model, channel.GetModels()) {
                seq = append(seq, channel)
            }
        }
        seq.Sort()
        m.PreflightSequence[model] = seq
    }
}

func (m *Manager) GetSequence() Sequence {
    return m.Sequence
}

func (m *Manager) GetActiveSequence() Sequence {
    var seq Sequence
    for _, channel := range m.Sequence {
        if channel.GetState() {
            seq = append(seq, channel)
        }
    }
    seq.Sort()
    return seq
}

func (m *Manager) GetModels() []string {
    return m.Models
}

func (m *Manager) GetPreflightSequence() map[string]Sequence {
    return m.PreflightSequence
}

// HitSequence returns the preflight sequence of the model
func (m *Manager) HitSequence(model string) Sequence {
    return m.PreflightSequence[model]
}

// HasChannel returns whether the channel exists
func (m *Manager) HasChannel(model string) bool {
    return utils.Contains(model, m.Models)
}

func (m *Manager) GetTicker(model string) *Ticker {
    return &Ticker{
        Sequence: m.HitSequence(model),
    }
}

func (m *Manager) Len() int {
    return len(m.Sequence)
}

func (m *Manager) GetMaxId() int {
    var max int
    for _, channel := range m.Sequence {
        if channel.Id > max {
            max = channel.Id
        }
    }
    return max
}

func (m *Manager) SaveConfig() error {
    return viper.UnmarshalKey("channel", &m.Sequence)
}

func (m *Manager) CreateChannel(channel *Channel) error {
    channel.Id = m.GetMaxId() + 1
    m.Sequence = append(m.Sequence, channel)
    return m.SaveConfig()
}

func (m *Manager) UpdateChannel(id int, channel *Channel) error {
    for i, item := range m.Sequence {
        if item.Id == id {
            m.Sequence[i] = channel
            return m.SaveConfig()
        }
    }
    return nil
}

func (m *Manager) DeleteChannel(id int) error {
    for i, item := range m.Sequence {
        if item.Id == id {
            m.Sequence = append(m.Sequence[:i], m.Sequence[i+1:]...)
            return m.SaveConfig()
        }
    }
    return nil
}

func (m *Manager) ActivateChannel(id int) error {
    for i, item := range m.Sequence {
        if item.Id == id {
            m.Sequence[i].State = true
            return m.SaveConfig()
        }
    }
    return nil
}

func (m *Manager) DeactivateChannel(id int) error {
    for i, item := range m.Sequence {
        if item.Id == id {
            m.Sequence[i].State = false
            return m.SaveConfig()
        }
    }
    return nil
}
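A usage sketch for the new Manager, assuming it is compiled inside this repository. It builds the sequence by hand instead of unmarshalling the viper channel key, uses plain strings for Type (matching the new string constants), and initializes PreflightSequence explicitly because Init assigns into that map:

    package main

    import (
        "chat/channel"
        "fmt"
    )

    func main() {
        m := &channel.Manager{
            Sequence: channel.Sequence{
                {Id: 1, Name: "openai-primary", Type: "openai", Priority: 0, Weight: 2, Retry: 3, State: true, Models: []string{"gpt-3.5-turbo", "gpt-4"}},
                {Id: 2, Name: "oneapi-fallback", Type: "oneapi", Priority: 1, Weight: 1, Retry: 1, State: true, Models: []string{"gpt-3.5-turbo"}},
            },
            PreflightSequence: map[string]channel.Sequence{},
        }
        m.Init()

        fmt.Println(m.GetModels()) // [gpt-3.5-turbo gpt-4]

        // channels that can serve gpt-3.5-turbo, lowest priority value first
        for _, c := range m.HitSequence("gpt-3.5-turbo") {
            fmt.Println(c.Id, c.Name, c.Priority)
        }
    }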
channel/sequence.go (new file, 29 lines)
@@ -0,0 +1,29 @@
package channel

import "sort"

func (s *Sequence) Len() int {
    return len(*s)
}

func (s *Sequence) Less(i, j int) bool {
    return (*s)[i].GetPriority() < (*s)[j].GetPriority()
}

func (s *Sequence) Swap(i, j int) {
    (*s)[i], (*s)[j] = (*s)[j], (*s)[i]
}

func (s *Sequence) GetChannelById(id int) *Channel {
    for _, channel := range *s {
        if channel.Id == id {
            return channel
        }
    }
    return nil
}

func (s *Sequence) Sort() {
    // sort by priority
    sort.Sort(s)
}
channel/ticker.go (new file, 81 lines)
@@ -0,0 +1,81 @@
package channel

import "chat/utils"

func (t *Ticker) GetChannelByPriority(priority int) *Channel {
    var stack Sequence

    for idx, channel := range t.Sequence {
        if channel.GetPriority() == priority {
            // get if the next channel has the same priority
            if idx+1 < len(t.Sequence) && t.Sequence[idx+1].GetPriority() == priority {
                stack = append(stack, channel)
                continue
            }

            if len(stack) == 0 {
                return channel
            }

            // stack is not empty
            stack = append(stack, channel)

            // sort by weight and break the loop
            if idx+1 >= len(t.Sequence) || t.Sequence[idx+1].GetPriority() != priority {
                stack.Sort()
                break
            }
        }
    }

    weight := utils.Each(stack, func(channel *Channel) int {
        return channel.GetWeight()
    })
    total := utils.Sum(weight)

    // get random number
    cursor := utils.Intn(total)

    // get channel by weight
    for _, channel := range stack {
        cursor -= channel.GetWeight()
        if cursor < 0 {
            return channel
        }
    }

    return stack[0]
}

func (t *Ticker) Next() *Channel {
    if t.Cursor >= len(t.Sequence) {
        // out of sequence
        return nil
    }

    priority := t.Sequence[t.Cursor].GetPriority()
    channel := t.GetChannelByPriority(priority)
    t.SkipPriority(priority)

    return channel
}

func (t *Ticker) SkipPriority(priority int) {
    for idx, channel := range t.Sequence {
        if channel.GetPriority() == priority {
            // get if the next channel does not have the same priority or out of sequence
            if idx+1 >= len(t.Sequence) || t.Sequence[idx+1].GetPriority() != priority {
                t.Cursor = idx + 1
                break
            }
        }
    }
}

func (t *Ticker) Skip() {
    t.Cursor++
}

func (t *Ticker) IsDone() bool {
    return t.Cursor >= len(t.Sequence)
}
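A usage sketch for the Ticker, again assuming the repository's chat/channel package. The sequence must already be sorted by priority (Manager.Init sorts each preflight sequence). Channels that share a priority form one group: GetChannelByPriority picks one of them at random, weighted by Weight, and SkipPriority then moves the cursor past the whole group, so each call to Next yields one candidate per priority level:

    package main

    import (
        "chat/channel"
        "fmt"
    )

    func main() {
        t := &channel.Ticker{
            Sequence: channel.Sequence{
                {Id: 1, Name: "a", Priority: 0, Weight: 3},
                {Id: 2, Name: "b", Priority: 0, Weight: 1},
                {Id: 3, Name: "c", Priority: 1, Weight: 1},
            },
        }

        for !t.IsDone() {
            if c := t.Next(); c != nil {
                fmt.Println("try channel", c.Id, c.Name)
            }
        }
        // prints channel 1 or 2 first (weighted roughly 3:1), then channel 3
    }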
@@ -1,24 +1,30 @@
package channel

type Channel struct {
Id int `json:"id"`
Name string `json:"name"`
Type string `json:"type"`
Priority int `json:"priority"`
Weight int `json:"weight"`
Models []string `json:"models"`
Retry int `json:"retry"`
Secret string `json:"secret"`
Endpoint string `json:"endpoint"`
Mapper string `json:"mapper"`
State bool `json:"state"`

Reflect *map[string]string `json:"reflect"`
HitModels *[]string `json:"hit_models"`
Id int `json:"id" mapstructure:"id"`
Name string `json:"name" mapstructure:"name"`
Type string `json:"type" mapstructure:"type"`
Priority int `json:"priority" mapstructure:"priority"`
Weight int `json:"weight" mapstructure:"weight"`
Models []string `json:"models" mapstructure:"models"`
Retry int `json:"retry" mapstructure:"retry"`
Secret string `json:"secret" mapstructure:"secret"`
Endpoint string `json:"endpoint" mapstructure:"endpoint"`
Mapper string `json:"mapper" mapstructure:"mapper"`
State bool `json:"state" mapstructure:"state"`
Reflect *map[string]string `json:"reflect" mapstructure:"reflect"`
HitModels *[]string `json:"hit_models" mapstructure:"hit_models"`
}

type Sequence []*Channel

type Manager struct {
Sequence Sequence `json:"sequence"`
PreflightSequence map[string]Sequence `json:"preflight_sequence"`
Models []string `json:"models"`
}

type Ticker struct {
Sequence Sequence `json:"sequence"`
Cursor int `json:"cursor"`
}
channel/worker.go (new file, 29 lines)
@@ -0,0 +1,29 @@
package channel

import (
    "chat/adapter"
    "chat/globals"
    "chat/utils"
    "fmt"
)

func NewChatRequest(props *adapter.ChatProps, hook globals.Hook) error {
    if !ManagerInstance.HasChannel(props.Model) {
        return fmt.Errorf("cannot find channel for model %s", props.Model)
    }

    ticker := ManagerInstance.GetTicker(props.Model)

    var err error
    for !ticker.IsDone() {
        if channel := ticker.Next(); channel != nil {
            props.MaxRetries = utils.ToPtr(channel.GetRetry())
            if err = adapter.NewChatRequest(channel, props, hook); err == nil {
                return nil
            }
            fmt.Println(fmt.Sprintf("[channel] hit error %s for model %s, goto next channel", err.Error(), props.Model))
        }
    }

    return err
}
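channel.NewChatRequest is the new entry point the handlers below switch to: it walks the ticker for the requested model and fails over to the next channel whenever one channel (after its own retries) still errors. A hedged calling sketch — the message literal, the role string, and the assumption that the manager has already been initialized and populated from config are illustrative:

    package main

    import (
        "chat/adapter"
        "chat/channel"
        "chat/globals"
        "fmt"
    )

    func main() {
        channel.InitManager() // as main.go now does at startup

        err := channel.NewChatRequest(&adapter.ChatProps{
            Model:   "gpt-3.5-turbo",
            Message: []globals.Message{{Role: "user", Content: "hello"}},
        }, func(data string) error {
            fmt.Print(data) // the hook receives streamed segments
            return nil
        })
        if err != nil {
            fmt.Println("request failed on every channel:", err)
        }
    }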
@@ -8,18 +8,18 @@ const (
)

const (
OpenAIChannelType = iota
ClaudeChannelType
SlackChannelType
SparkdeskChannelType
ChatGLMChannelType
DashscopeChannelType
HunyuanChannelType
ZhinaoChannelType
BaichuanChannelType
SkylarkChannelType
BingChannelType
PalmChannelType
MidjourneyChannelType
OneAPIChannelType
OpenAIChannelType = "openai"
ClaudeChannelType = "claude"
SlackChannelType = "slack"
SparkdeskChannelType = "sparkdesk"
ChatGLMChannelType = "chatglm"
QwenChannelType = "qwen"
HunyuanChannelType = "hunyuan"
ZhinaoChannelType = "zhinao"
BaichuanChannelType = "baichuan"
SkylarkChannelType = "skylark"
BingChannelType = "bing"
PalmChannelType = "palm"
MidjourneyChannelType = "midjourney"
OneAPIChannelType = "oneapi"
)
globals/interface.go (new file, 10 lines)
@@ -0,0 +1,10 @@
package globals

type ChannelConfig interface {
    GetType() string
    GetModelReflect(model string) string
    GetRetry() int
    GetRandomSecret() string
    GetEndpoint() string
    ProcessError(err error) error
}
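globals.ChannelConfig is the interface adapter.createChatRequest receives; *channel.Channel is the intended implementation (GetModelReflect, GetRandomSecret, GetEndpoint and ProcessError appear in channel/channel.go above, and GetType/GetRetry are assumed to be plain getters on the same struct). A compile-time assertion, placed in any package that imports both, documents that contract — it only compiles if every listed method really exists:

    package main

    import (
        "chat/channel"
        "chat/globals"
    )

    // compile-time check that *channel.Channel satisfies globals.ChannelConfig
    var _ globals.ChannelConfig = (*channel.Channel)(nil)

    func main() {}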
main.go (2 lines changed)
@@ -5,6 +5,7 @@ import (
"chat/addition"
"chat/admin"
"chat/auth"
"chat/channel"
"chat/cli"
"chat/manager"
"chat/manager/conversation"
@@ -23,6 +24,7 @@ func main() {
if cli.Run() {
return
}
channel.InitManager()

app := gin.Default()
middleware.RegisterMiddleware(app)
@@ -5,6 +5,7 @@ import (
"chat/addition/web"
"chat/admin"
"chat/auth"
"chat/channel"
"chat/globals"
"chat/manager/conversation"
"chat/utils"
@@ -86,7 +87,7 @@ func ChatHandler(conn *Connection, user *auth.User, instance *conversation.Conve
}

buffer := utils.NewBuffer(model, segment)
err := adapter.NewChatRequest(&adapter.ChatProps{
err := channel.NewChatRequest(&adapter.ChatProps{
Model: model,
Message: segment,
Plan: plan,
@@ -5,6 +5,7 @@ import (
"chat/addition/web"
"chat/admin"
"chat/auth"
"chat/channel"
"chat/globals"
"chat/utils"
"fmt"
@@ -39,7 +40,7 @@ func NativeChatHandler(c *gin.Context, user *auth.User, model string, message []
}

buffer := utils.NewBuffer(model, segment)
err := adapter.NewChatRequest(&adapter.ChatProps{
err := channel.NewChatRequest(&adapter.ChatProps{
Model: model,
Plan: plan,
Message: segment,
@@ -5,6 +5,7 @@ import (
"chat/addition/web"
"chat/admin"
"chat/auth"
"chat/channel"
"chat/globals"
"chat/utils"
"fmt"
@@ -70,7 +71,7 @@ type TranshipmentStreamResponse struct {
}

func ModelAPI(c *gin.Context) {
c.JSON(http.StatusOK, globals.AllModels)
c.JSON(http.StatusOK, channel.ManagerInstance.GetModels())
}

func TranshipmentAPI(c *gin.Context) {
@@ -162,7 +163,7 @@ func sendTranshipmentResponse(c *gin.Context, form TranshipmentForm, id string,
cache := utils.GetCacheFromContext(c)

buffer := utils.NewBuffer(form.Model, form.Messages)
err := adapter.NewChatRequest(GetProps(form, buffer, plan), func(data string) error {
err := channel.NewChatRequest(GetProps(form, buffer, plan), func(data string) error {
buffer.Write(data)
return nil
})
@@ -221,34 +222,34 @@ func getStreamTranshipmentForm(id string, created int64, form TranshipmentForm,
}

func sendStreamTranshipmentResponse(c *gin.Context, form TranshipmentForm, id string, created int64, user *auth.User, plan bool) {
channel := make(chan TranshipmentStreamResponse)
partial := make(chan TranshipmentStreamResponse)
db := utils.GetDBFromContext(c)
cache := utils.GetCacheFromContext(c)

go func() {
buffer := utils.NewBuffer(form.Model, form.Messages)
err := adapter.NewChatRequest(GetProps(form, buffer, plan), func(data string) error {
channel <- getStreamTranshipmentForm(id, created, form, buffer.Write(data), buffer, false)
err := channel.NewChatRequest(GetProps(form, buffer, plan), func(data string) error {
partial <- getStreamTranshipmentForm(id, created, form, buffer.Write(data), buffer, false)
return nil
})

admin.AnalysisRequest(form.Model, buffer, err)
if err != nil {
auth.RevertSubscriptionUsage(db, cache, user, form.Model)
channel <- getStreamTranshipmentForm(id, created, form, fmt.Sprintf("Error: %s", err.Error()), buffer, true)
partial <- getStreamTranshipmentForm(id, created, form, fmt.Sprintf("Error: %s", err.Error()), buffer, true)
CollectQuota(c, user, buffer, plan)
close(channel)
close(partial)
return
}

channel <- getStreamTranshipmentForm(id, created, form, "", buffer, true)
partial <- getStreamTranshipmentForm(id, created, form, "", buffer, true)
CollectQuota(c, user, buffer, plan)
close(channel)
close(partial)
return
}()

c.Stream(func(w io.Writer) bool {
if resp, ok := <-channel; ok {
if resp, ok := <-partial; ok {
c.Render(-1, utils.NewEvent(resp))
return true
}
@@ -3,9 +3,16 @@ package utils
import (
"fmt"
"github.com/goccy/go-json"
"math/rand"
"time"
)

func Intn(n int) int {
source := rand.NewSource(time.Now().UnixNano())
r := rand.New(source)
return r.Intn(n)
}

func Sum[T int | int64 | float32 | float64](arr []T) T {
var res T
for _, v := range arr {
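utils.Intn seeds a fresh source from the wall clock on every call: that sidesteps the fixed default seed of math/rand on older Go versions without touching global state, at the cost of allocating a source per call (and two calls within the same nanosecond would repeat a value). An alternative sketch with one shared seeded source — not what the commit does, and note that a shared *rand.Rand is not safe for concurrent use without extra locking, which is one argument for the per-call approach:

    package main

    import (
        "fmt"
        "math/rand"
        "time"
    )

    // one seeded source shared by all callers (would need a mutex under concurrency)
    var seeded = rand.New(rand.NewSource(time.Now().UnixNano()))

    func IntnShared(n int) int {
        return seeded.Intn(n)
    }

    func main() {
        fmt.Println(IntnShared(10))
    }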
@@ -3,7 +3,6 @@ package utils
import (
"fmt"
"github.com/goccy/go-json"
"math/rand"
"regexp"
"strconv"
"strings"
@@ -11,13 +10,13 @@ import (
)

func GetRandomInt(min int, max int) int {
return rand.Intn(max-min) + min
return Intn(max-min) + min
}

func GenerateCode(length int) string {
var code string
for i := 0; i < length; i++ {
code += strconv.Itoa(rand.Intn(10))
code += strconv.Itoa(Intn(10))
}
return code
}
@@ -26,7 +25,7 @@ func GenerateChar(length int) string {
const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
result := make([]byte, length)
for i := 0; i < length; i++ {
result[i] = charset[rand.Intn(len(charset))]
result[i] = charset[Intn(len(charset))]
}
return string(result)
}
utils/key.go (deleted, 12 lines)
@@ -1,12 +0,0 @@
package utils

import (
    "math/rand"
    "strings"
)

func GetRandomKey(apikey string) string {
    arr := strings.Split(apikey, "|")
    idx := rand.Intn(len(arr))
    return arr[idx]
}