update: channel worker, channel sequence and ticker

This commit is contained in:
Zhang Minghan 2023-12-01 22:48:25 +08:00
parent db7acee643
commit 8e3a424e60
25 changed files with 478 additions and 192 deletions

1
.gitignore vendored
View File

@ -2,6 +2,7 @@ node_modules
.vscode
.idea
config.yaml
config.dev.yaml
addition/generation/data/*
!addition/generation/data/.gitkeep

View File

@ -17,11 +17,9 @@ import (
"chat/adapter/zhipuai"
"chat/globals"
"chat/utils"
"fmt"
)
var defaultMaxRetries = 3
var midjourneyMaxRetries = 10
type RequestProps struct {
MaxRetries *int
Current int
@ -46,36 +44,21 @@ type ChatProps struct {
Buffer utils.Buffer
}
func createChatRequest(props *ChatProps, hook globals.Hook) error {
if oneapi.IsHit(props.Model) {
return oneapi.NewChatInstanceFromConfig().CreateStreamChatRequest(&oneapi.ChatProps{
Model: props.Model,
Message: props.Message,
Token: utils.Multi(
props.Token == 0,
utils.Multi(globals.IsGPT4Model(props.Model) || props.Plan || props.Infinity, nil, utils.ToPtr(2500)),
&props.Token,
),
PresencePenalty: props.PresencePenalty,
FrequencyPenalty: props.FrequencyPenalty,
Temperature: props.Temperature,
TopP: props.TopP,
Tools: props.Tools,
ToolChoice: props.ToolChoice,
Buffer: props.Buffer,
}, hook)
func createChatRequest(conf globals.ChannelConfig, props *ChatProps, hook globals.Hook) error {
model := conf.GetModelReflect(props.Model)
} else if globals.IsChatGPTModel(props.Model) {
switch conf.GetType() {
case globals.OpenAIChannelType:
instance := chatgpt.NewChatInstanceFromModel(&chatgpt.InstanceProps{
Model: props.Model,
Model: model,
Plan: props.Plan,
})
return instance.CreateStreamChatRequest(&chatgpt.ChatProps{
Model: props.Model,
Model: model,
Message: props.Message,
Token: utils.Multi(
props.Token == 0,
utils.Multi(globals.IsGPT4Model(props.Model) || props.Plan || props.Infinity, nil, utils.ToPtr(2500)),
utils.Multi(globals.IsGPT4Model(model) || props.Plan || props.Infinity, nil, utils.ToPtr(2500)),
&props.Token,
),
PresencePenalty: props.PresencePenalty,
@ -87,9 +70,9 @@ func createChatRequest(props *ChatProps, hook globals.Hook) error {
Buffer: props.Buffer,
}, hook)
} else if globals.IsClaudeModel(props.Model) {
case globals.ClaudeChannelType:
return claude.NewChatInstanceFromConfig().CreateStreamChatRequest(&claude.ChatProps{
Model: props.Model,
Model: model,
Message: props.Message,
Token: utils.Multi(props.Token == 0, 50000, props.Token),
TopP: props.TopP,
@ -97,9 +80,26 @@ func createChatRequest(props *ChatProps, hook globals.Hook) error {
Temperature: props.Temperature,
}, hook)
} else if globals.IsSparkDeskModel(props.Model) {
return sparkdesk.NewChatInstance(props.Model).CreateStreamChatRequest(&sparkdesk.ChatProps{
Model: props.Model,
case globals.SlackChannelType:
return slack.NewChatInstanceFromConfig().CreateStreamChatRequest(&slack.ChatProps{
Message: props.Message,
}, hook)
case globals.BingChannelType:
return bing.NewChatInstanceFromConfig().CreateStreamChatRequest(&bing.ChatProps{
Model: model,
Message: props.Message,
}, hook)
case globals.PalmChannelType:
return palm2.NewChatInstanceFromConfig().CreateStreamChatRequest(&palm2.ChatProps{
Model: model,
Message: props.Message,
}, hook)
case globals.SparkdeskChannelType:
return sparkdesk.NewChatInstance(model).CreateStreamChatRequest(&sparkdesk.ChatProps{
Model: model,
Message: props.Message,
Token: utils.Multi(props.Token == 0, nil, utils.ToPtr(props.Token)),
Temperature: props.Temperature,
@ -108,34 +108,17 @@ func createChatRequest(props *ChatProps, hook globals.Hook) error {
Buffer: props.Buffer,
}, hook)
} else if globals.IsPalm2Model(props.Model) {
return palm2.NewChatInstanceFromConfig().CreateStreamChatRequest(&palm2.ChatProps{
Model: props.Model,
Message: props.Message,
}, hook)
} else if globals.IsSlackModel(props.Model) {
return slack.NewChatInstanceFromConfig().CreateStreamChatRequest(&slack.ChatProps{
Message: props.Message,
}, hook)
} else if globals.IsBingModel(props.Model) {
return bing.NewChatInstanceFromConfig().CreateStreamChatRequest(&bing.ChatProps{
Model: props.Model,
Message: props.Message,
}, hook)
} else if globals.IsZhiPuModel(props.Model) {
case globals.ChatGLMChannelType:
return zhipuai.NewChatInstanceFromConfig().CreateStreamChatRequest(&zhipuai.ChatProps{
Model: props.Model,
Model: model,
Message: props.Message,
Temperature: props.Temperature,
TopP: props.TopP,
}, hook)
} else if globals.IsQwenModel(props.Model) {
case globals.QwenChannelType:
return dashscope.NewChatInstanceFromConfig().CreateStreamChatRequest(&dashscope.ChatProps{
Model: props.Model,
Model: model,
Message: props.Message,
Token: utils.Multi(props.Infinity || props.Plan, 2048, props.Token),
Temperature: props.Temperature,
@ -144,43 +127,26 @@ func createChatRequest(props *ChatProps, hook globals.Hook) error {
RepetitionPenalty: props.RepetitionPenalty,
}, hook)
} else if globals.IsMidjourneyModel(props.Model) {
return midjourney.NewChatInstanceFromConfig().CreateStreamChatRequest(&midjourney.ChatProps{
Model: props.Model,
Messages: props.Message,
}, hook)
} else if globals.IsHunyuanModel(props.Model) {
case globals.HunyuanChannelType:
return hunyuan.NewChatInstanceFromConfig().CreateStreamChatRequest(&hunyuan.ChatProps{
Model: props.Model,
Model: model,
Message: props.Message,
Temperature: props.Temperature,
TopP: props.TopP,
}, hook)
} else if globals.Is360Model(props.Model) {
return zhinao.NewChatInstanceFromConfig().CreateStreamChatRequest(&zhinao.ChatProps{
Model: props.Model,
Message: props.Message,
Token: utils.Multi(props.Infinity || props.Plan, nil, utils.ToPtr(2048)),
TopP: props.TopP,
TopK: props.TopK,
Temperature: props.Temperature,
RepetitionPenalty: props.RepetitionPenalty,
}, hook)
} else if globals.IsBaichuanModel(props.Model) {
case globals.BaichuanChannelType:
return baichuan.NewChatInstanceFromConfig().CreateStreamChatRequest(&baichuan.ChatProps{
Model: props.Model,
Model: model,
Message: props.Message,
TopP: props.TopP,
TopK: props.TopK,
Temperature: props.Temperature,
}, hook)
} else if globals.IsSkylarkModel(props.Model) {
case globals.SkylarkChannelType:
return skylark.NewChatInstanceFromConfig().CreateStreamChatRequest(&skylark.ChatProps{
Model: props.Model,
Model: model,
Message: props.Message,
Token: utils.Multi(props.Token == 0, 4096, props.Token),
TopP: props.TopP,
@ -191,7 +157,43 @@ func createChatRequest(props *ChatProps, hook globals.Hook) error {
RepeatPenalty: props.RepetitionPenalty,
Tools: props.Tools,
}, hook)
}
return hook("Sorry, we cannot find the model you are looking for. Please try another model.")
case globals.ZhinaoChannelType:
return zhinao.NewChatInstanceFromConfig().CreateStreamChatRequest(&zhinao.ChatProps{
Model: model,
Message: props.Message,
Token: utils.Multi(props.Infinity || props.Plan, nil, utils.ToPtr(2048)),
TopP: props.TopP,
TopK: props.TopK,
Temperature: props.Temperature,
RepetitionPenalty: props.RepetitionPenalty,
}, hook)
case globals.MidjourneyChannelType:
return midjourney.NewChatInstanceFromConfig().CreateStreamChatRequest(&midjourney.ChatProps{
Model: model,
Messages: props.Message,
}, hook)
case globals.OneAPIChannelType:
return oneapi.NewChatInstanceFromConfig().CreateStreamChatRequest(&oneapi.ChatProps{
Model: model,
Message: props.Message,
Token: utils.Multi(
props.Token == 0,
utils.Multi(globals.IsGPT4Model(model) || props.Plan || props.Infinity, nil, utils.ToPtr(2500)),
&props.Token,
),
PresencePenalty: props.PresencePenalty,
FrequencyPenalty: props.FrequencyPenalty,
Temperature: props.Temperature,
TopP: props.TopP,
Tools: props.Tools,
ToolChoice: props.ToolChoice,
Buffer: props.Buffer,
}, hook)
default:
return fmt.Errorf("unknown channel type %s for model %s", conf.GetType(), props.Model)
}
}

View File

@ -2,7 +2,6 @@ package chatgpt
import (
"chat/globals"
"chat/utils"
"fmt"
"github.com/spf13/viper"
)
@ -42,7 +41,7 @@ func NewChatInstance(endpoint, apiKey string) *ChatInstance {
func NewChatInstanceFromConfig(v string) *ChatInstance {
return NewChatInstance(
viper.GetString(fmt.Sprintf("openai.%s.endpoint", v)),
utils.GetRandomKey(viper.GetString(fmt.Sprintf("openai.%s.apikey", v))),
viper.GetString(fmt.Sprintf("openai.%s.apikey", v)),
)
}

View File

@ -1,7 +1,6 @@
package claude
import (
"chat/utils"
"github.com/spf13/viper"
)
@ -20,7 +19,7 @@ func NewChatInstance(endpoint, apiKey string) *ChatInstance {
func NewChatInstanceFromConfig() *ChatInstance {
return NewChatInstance(
viper.GetString("claude.endpoint"),
utils.GetRandomKey(viper.GetString("claude.apikey")),
viper.GetString("claude.apikey"),
)
}

View File

@ -1,7 +1,6 @@
package dashscope
import (
"chat/utils"
"github.com/spf13/viper"
)
@ -28,6 +27,6 @@ func NewChatInstance(endpoint string, apiKey string) *ChatInstance {
func NewChatInstanceFromConfig() *ChatInstance {
return NewChatInstance(
viper.GetString("dashscope.endpoint"),
utils.GetRandomKey(viper.GetString("dashscope.apikey")),
viper.GetString("dashscope.apikey"),
)
}

View File

@ -1,26 +0,0 @@
package oneapi
import (
"chat/globals"
)
var HitModels = []string{
globals.Claude1, globals.Claude1100k,
globals.Claude2, globals.Claude2100k,
globals.StableDiffusion,
globals.LLaMa270B, globals.LLaMa213B, globals.LLaMa27B,
globals.CodeLLaMa34B, globals.CodeLLaMa13B, globals.CodeLLaMa7B,
}
func (c *ChatInstance) GetToken(model string) int {
switch model {
case globals.Claude1, globals.Claude2:
return 5000
case globals.Claude2100k, globals.Claude1100k:
return 50000
case globals.LLaMa270B, globals.LLaMa213B, globals.LLaMa27B, globals.CodeLLaMa34B, globals.CodeLLaMa13B, globals.CodeLLaMa7B:
return 3000
default:
return 2500
}
}

View File

@ -1,7 +1,6 @@
package oneapi
import (
"chat/utils"
"fmt"
"github.com/spf13/viper"
)
@ -44,7 +43,3 @@ func NewChatInstanceFromConfig() *ChatInstance {
viper.GetString("oneapi.apikey"),
)
}
func IsHit(model string) bool {
return utils.Contains[string](model, HitModels)
}

View File

@ -1,7 +1,6 @@
package palm2
import (
"chat/utils"
"github.com/spf13/viper"
)
@ -28,6 +27,6 @@ func NewChatInstance(endpoint string, apiKey string) *ChatInstance {
func NewChatInstanceFromConfig() *ChatInstance {
return NewChatInstance(
viper.GetString("palm2.endpoint"),
utils.GetRandomKey(viper.GetString("palm2.apikey")),
viper.GetString("palm2.apikey"),
)
}

View File

@ -20,21 +20,10 @@ func isQPSOverLimit(model string, err error) bool {
}
}
func getRetries(model string, retries *int) int {
if retries == nil {
if globals.IsMidjourneyModel(model) {
return midjourneyMaxRetries
}
return defaultMaxRetries
}
func NewChatRequest(conf globals.ChannelConfig, props *ChatProps, hook globals.Hook) error {
err := createChatRequest(conf, props, hook)
return *retries
}
func NewChatRequest(props *ChatProps, hook globals.Hook) error {
err := createChatRequest(props, hook)
retries := getRetries(props.Model, props.MaxRetries)
retries := conf.GetRetry()
props.Current++
if IsAvailableError(err) {
@ -43,14 +32,14 @@ func NewChatRequest(props *ChatProps, hook globals.Hook) error {
fmt.Println(fmt.Sprintf("qps limit for %s, sleep and retry (times: %d)", props.Model, props.Current))
time.Sleep(500 * time.Millisecond)
return NewChatRequest(props, hook)
return NewChatRequest(conf, props, hook)
}
if props.Current < retries {
fmt.Println(fmt.Sprintf("retrying chat request for %s (attempt %d/%d, error: %s)", props.Model, props.Current+1, retries, err.Error()))
return NewChatRequest(props, hook)
return NewChatRequest(conf, props, hook)
}
}
return err
return conf.ProcessError(err)
}

View File

@ -3,6 +3,7 @@ package generation
import (
"chat/adapter"
"chat/admin"
"chat/channel"
"chat/globals"
"chat/utils"
"fmt"
@ -16,7 +17,7 @@ func CreateGeneration(model string, prompt string, path string, plan bool, hook
message := GenerateMessage(prompt)
buffer := utils.NewBuffer(model, message)
err := adapter.NewChatRequest(&adapter.ChatProps{
err := channel.NewChatRequest(&adapter.ChatProps{
Model: model,
Message: message,
Plan: plan,

View File

@ -2,7 +2,8 @@ package channel
import (
"chat/utils"
"math/rand"
"errors"
"fmt"
"strings"
)
@ -46,9 +47,10 @@ func (c *Channel) GetSecret() string {
return c.Secret
}
// GetRandomSecret returns a random secret from the secret list
func (c *Channel) GetRandomSecret() string {
arr := strings.Split(c.GetSecret(), "\n")
idx := rand.Intn(len(arr))
idx := utils.Intn(len(arr))
return arr[idx]
}
@ -77,6 +79,7 @@ func (c *Channel) GetReflect() map[string]string {
return *c.Reflect
}
// GetModelReflect returns the reflection model name if it exists, otherwise returns the original model name
func (c *Channel) GetModelReflect(model string) string {
ref := c.GetReflect()
if reflect, ok := ref[model]; ok && len(reflect) > 0 {
@ -114,3 +117,21 @@ func (c *Channel) GetHitModels() []string {
func (c *Channel) GetState() bool {
return c.State
}
func (c *Channel) IsHit(model string) bool {
return utils.Contains(model, c.GetHitModels())
}
func (c *Channel) ProcessError(err error) error {
if err == nil {
return nil
}
content := err.Error()
if strings.Contains(content, c.GetEndpoint()) {
// hide the endpoint
replacer := fmt.Sprintf("channel://%d", c.GetId())
content = strings.Replace(content, c.GetEndpoint(), replacer, -1)
}
return errors.New(content)
}

153
channel/manager.go Normal file
View File

@ -0,0 +1,153 @@
package channel
import (
"chat/utils"
"github.com/spf13/viper"
)
var ManagerInstance *Manager
func InitManager() {
ManagerInstance = NewManager()
}
func NewManager() *Manager {
var seq Sequence
if err := viper.UnmarshalKey("channel", &seq); err != nil {
panic(err)
}
// sort by priority
manager := &Manager{
Sequence: seq,
}
return manager
}
func (m *Manager) Init() {
// init support models
for _, channel := range m.GetActiveSequence() {
for _, model := range channel.GetModels() {
if !utils.Contains(model, m.Models) {
m.Models = append(m.Models, model)
}
}
}
// init preflight sequence
for _, model := range m.Models {
var seq Sequence
for _, channel := range m.GetActiveSequence() {
if utils.Contains(model, channel.GetModels()) {
seq = append(seq, channel)
}
}
seq.Sort()
m.PreflightSequence[model] = seq
}
}
func (m *Manager) GetSequence() Sequence {
return m.Sequence
}
func (m *Manager) GetActiveSequence() Sequence {
var seq Sequence
for _, channel := range m.Sequence {
if channel.GetState() {
seq = append(seq, channel)
}
}
seq.Sort()
return seq
}
func (m *Manager) GetModels() []string {
return m.Models
}
func (m *Manager) GetPreflightSequence() map[string]Sequence {
return m.PreflightSequence
}
// HitSequence returns the preflight sequence of the model
func (m *Manager) HitSequence(model string) Sequence {
return m.PreflightSequence[model]
}
// HasChannel returns whether the channel exists
func (m *Manager) HasChannel(model string) bool {
return utils.Contains(model, m.Models)
}
func (m *Manager) GetTicker(model string) *Ticker {
return &Ticker{
Sequence: m.HitSequence(model),
}
}
func (m *Manager) Len() int {
return len(m.Sequence)
}
func (m *Manager) GetMaxId() int {
var max int
for _, channel := range m.Sequence {
if channel.Id > max {
max = channel.Id
}
}
return max
}
func (m *Manager) SaveConfig() error {
return viper.UnmarshalKey("channel", &m.Sequence)
}
func (m *Manager) CreateChannel(channel *Channel) error {
channel.Id = m.GetMaxId() + 1
m.Sequence = append(m.Sequence, channel)
return m.SaveConfig()
}
func (m *Manager) UpdateChannel(id int, channel *Channel) error {
for i, item := range m.Sequence {
if item.Id == id {
m.Sequence[i] = channel
return m.SaveConfig()
}
}
return nil
}
func (m *Manager) DeleteChannel(id int) error {
for i, item := range m.Sequence {
if item.Id == id {
m.Sequence = append(m.Sequence[:i], m.Sequence[i+1:]...)
return m.SaveConfig()
}
}
return nil
}
func (m *Manager) ActivateChannel(id int) error {
for i, item := range m.Sequence {
if item.Id == id {
m.Sequence[i].State = true
return m.SaveConfig()
}
}
return nil
}
func (m *Manager) DeactivateChannel(id int) error {
for i, item := range m.Sequence {
if item.Id == id {
m.Sequence[i].State = false
return m.SaveConfig()
}
}
return nil
}

29
channel/sequence.go Normal file
View File

@ -0,0 +1,29 @@
package channel
import "sort"
func (s *Sequence) Len() int {
return len(*s)
}
func (s *Sequence) Less(i, j int) bool {
return (*s)[i].GetPriority() < (*s)[j].GetPriority()
}
func (s *Sequence) Swap(i, j int) {
(*s)[i], (*s)[j] = (*s)[j], (*s)[i]
}
func (s *Sequence) GetChannelById(id int) *Channel {
for _, channel := range *s {
if channel.Id == id {
return channel
}
}
return nil
}
func (s *Sequence) Sort() {
// sort by priority
sort.Sort(s)
}

81
channel/ticker.go Normal file
View File

@ -0,0 +1,81 @@
package channel
import "chat/utils"
func (t *Ticker) GetChannelByPriority(priority int) *Channel {
var stack Sequence
for idx, channel := range t.Sequence {
if channel.GetPriority() == priority {
// get if the next channel has the same priority
if idx+1 < len(t.Sequence) && t.Sequence[idx+1].GetPriority() == priority {
stack = append(stack, channel)
continue
}
if len(stack) == 0 {
return channel
}
// stack is not empty
stack = append(stack, channel)
// sort by weight and break the loop
if idx+1 >= len(t.Sequence) || t.Sequence[idx+1].GetPriority() != priority {
stack.Sort()
break
}
}
}
weight := utils.Each(stack, func(channel *Channel) int {
return channel.GetWeight()
})
total := utils.Sum(weight)
// get random number
cursor := utils.Intn(total)
// get channel by weight
for _, channel := range stack {
cursor -= channel.GetWeight()
if cursor < 0 {
return channel
}
}
return stack[0]
}
func (t *Ticker) Next() *Channel {
if t.Cursor >= len(t.Sequence) {
// out of sequence
return nil
}
priority := t.Sequence[t.Cursor].GetPriority()
channel := t.GetChannelByPriority(priority)
t.SkipPriority(priority)
return channel
}
func (t *Ticker) SkipPriority(priority int) {
for idx, channel := range t.Sequence {
if channel.GetPriority() == priority {
// get if the next channel does not have the same priority or out of sequence
if idx+1 >= len(t.Sequence) || t.Sequence[idx+1].GetPriority() != priority {
t.Cursor = idx + 1
break
}
}
}
}
func (t *Ticker) Skip() {
t.Cursor++
}
func (t *Ticker) IsDone() bool {
return t.Cursor >= len(t.Sequence)
}

View File

@ -1,24 +1,30 @@
package channel
type Channel struct {
Id int `json:"id"`
Name string `json:"name"`
Type string `json:"type"`
Priority int `json:"priority"`
Weight int `json:"weight"`
Models []string `json:"models"`
Retry int `json:"retry"`
Secret string `json:"secret"`
Endpoint string `json:"endpoint"`
Mapper string `json:"mapper"`
State bool `json:"state"`
Reflect *map[string]string `json:"reflect"`
HitModels *[]string `json:"hit_models"`
Id int `json:"id" mapstructure:"id"`
Name string `json:"name" mapstructure:"name"`
Type string `json:"type" mapstructure:"type"`
Priority int `json:"priority" mapstructure:"priority"`
Weight int `json:"weight" mapstructure:"weight"`
Models []string `json:"models" mapstructure:"models"`
Retry int `json:"retry" mapstructure:"retry"`
Secret string `json:"secret" mapstructure:"secret"`
Endpoint string `json:"endpoint" mapstructure:"endpoint"`
Mapper string `json:"mapper" mapstructure:"mapper"`
State bool `json:"state" mapstructure:"state"`
Reflect *map[string]string `json:"reflect" mapstructure:"reflect"`
HitModels *[]string `json:"hit_models" mapstructure:"hit_models"`
}
type Sequence []*Channel
type Manager struct {
Sequence Sequence `json:"sequence"`
PreflightSequence map[string]Sequence `json:"preflight_sequence"`
Models []string `json:"models"`
}
type Ticker struct {
Sequence Sequence `json:"sequence"`
Cursor int `json:"cursor"`
}

29
channel/worker.go Normal file
View File

@ -0,0 +1,29 @@
package channel
import (
"chat/adapter"
"chat/globals"
"chat/utils"
"fmt"
)
func NewChatRequest(props *adapter.ChatProps, hook globals.Hook) error {
if !ManagerInstance.HasChannel(props.Model) {
return fmt.Errorf("cannot find channel for model %s", props.Model)
}
ticker := ManagerInstance.GetTicker(props.Model)
var err error
for !ticker.IsDone() {
if channel := ticker.Next(); channel != nil {
props.MaxRetries = utils.ToPtr(channel.GetRetry())
if err = adapter.NewChatRequest(channel, props, hook); err == nil {
return nil
}
fmt.Println(fmt.Sprintf("[channel] hit error %s for model %s, goto next channel", err.Error(), props.Model))
}
}
return err
}

View File

@ -8,18 +8,18 @@ const (
)
const (
OpenAIChannelType = iota
ClaudeChannelType
SlackChannelType
SparkdeskChannelType
ChatGLMChannelType
DashscopeChannelType
HunyuanChannelType
ZhinaoChannelType
BaichuanChannelType
SkylarkChannelType
BingChannelType
PalmChannelType
MidjourneyChannelType
OneAPIChannelType
OpenAIChannelType = "openai"
ClaudeChannelType = "claude"
SlackChannelType = "slack"
SparkdeskChannelType = "sparkdesk"
ChatGLMChannelType = "chatglm"
QwenChannelType = "qwen"
HunyuanChannelType = "hunyuan"
ZhinaoChannelType = "zhinao"
BaichuanChannelType = "baichuan"
SkylarkChannelType = "skylark"
BingChannelType = "bing"
PalmChannelType = "palm"
MidjourneyChannelType = "midjourney"
OneAPIChannelType = "oneapi"
)

10
globals/interface.go Normal file
View File

@ -0,0 +1,10 @@
package globals
type ChannelConfig interface {
GetType() string
GetModelReflect(model string) string
GetRetry() int
GetRandomSecret() string
GetEndpoint() string
ProcessError(err error) error
}

View File

@ -5,6 +5,7 @@ import (
"chat/addition"
"chat/admin"
"chat/auth"
"chat/channel"
"chat/cli"
"chat/manager"
"chat/manager/conversation"
@ -23,6 +24,7 @@ func main() {
if cli.Run() {
return
}
channel.InitManager()
app := gin.Default()
middleware.RegisterMiddleware(app)

View File

@ -5,6 +5,7 @@ import (
"chat/addition/web"
"chat/admin"
"chat/auth"
"chat/channel"
"chat/globals"
"chat/manager/conversation"
"chat/utils"
@ -86,7 +87,7 @@ func ChatHandler(conn *Connection, user *auth.User, instance *conversation.Conve
}
buffer := utils.NewBuffer(model, segment)
err := adapter.NewChatRequest(&adapter.ChatProps{
err := channel.NewChatRequest(&adapter.ChatProps{
Model: model,
Message: segment,
Plan: plan,

View File

@ -5,6 +5,7 @@ import (
"chat/addition/web"
"chat/admin"
"chat/auth"
"chat/channel"
"chat/globals"
"chat/utils"
"fmt"
@ -39,7 +40,7 @@ func NativeChatHandler(c *gin.Context, user *auth.User, model string, message []
}
buffer := utils.NewBuffer(model, segment)
err := adapter.NewChatRequest(&adapter.ChatProps{
err := channel.NewChatRequest(&adapter.ChatProps{
Model: model,
Plan: plan,
Message: segment,

View File

@ -5,6 +5,7 @@ import (
"chat/addition/web"
"chat/admin"
"chat/auth"
"chat/channel"
"chat/globals"
"chat/utils"
"fmt"
@ -70,7 +71,7 @@ type TranshipmentStreamResponse struct {
}
func ModelAPI(c *gin.Context) {
c.JSON(http.StatusOK, globals.AllModels)
c.JSON(http.StatusOK, channel.ManagerInstance.GetModels())
}
func TranshipmentAPI(c *gin.Context) {
@ -162,7 +163,7 @@ func sendTranshipmentResponse(c *gin.Context, form TranshipmentForm, id string,
cache := utils.GetCacheFromContext(c)
buffer := utils.NewBuffer(form.Model, form.Messages)
err := adapter.NewChatRequest(GetProps(form, buffer, plan), func(data string) error {
err := channel.NewChatRequest(GetProps(form, buffer, plan), func(data string) error {
buffer.Write(data)
return nil
})
@ -221,34 +222,34 @@ func getStreamTranshipmentForm(id string, created int64, form TranshipmentForm,
}
func sendStreamTranshipmentResponse(c *gin.Context, form TranshipmentForm, id string, created int64, user *auth.User, plan bool) {
channel := make(chan TranshipmentStreamResponse)
partial := make(chan TranshipmentStreamResponse)
db := utils.GetDBFromContext(c)
cache := utils.GetCacheFromContext(c)
go func() {
buffer := utils.NewBuffer(form.Model, form.Messages)
err := adapter.NewChatRequest(GetProps(form, buffer, plan), func(data string) error {
channel <- getStreamTranshipmentForm(id, created, form, buffer.Write(data), buffer, false)
err := channel.NewChatRequest(GetProps(form, buffer, plan), func(data string) error {
partial <- getStreamTranshipmentForm(id, created, form, buffer.Write(data), buffer, false)
return nil
})
admin.AnalysisRequest(form.Model, buffer, err)
if err != nil {
auth.RevertSubscriptionUsage(db, cache, user, form.Model)
channel <- getStreamTranshipmentForm(id, created, form, fmt.Sprintf("Error: %s", err.Error()), buffer, true)
partial <- getStreamTranshipmentForm(id, created, form, fmt.Sprintf("Error: %s", err.Error()), buffer, true)
CollectQuota(c, user, buffer, plan)
close(channel)
close(partial)
return
}
channel <- getStreamTranshipmentForm(id, created, form, "", buffer, true)
partial <- getStreamTranshipmentForm(id, created, form, "", buffer, true)
CollectQuota(c, user, buffer, plan)
close(channel)
close(partial)
return
}()
c.Stream(func(w io.Writer) bool {
if resp, ok := <-channel; ok {
if resp, ok := <-partial; ok {
c.Render(-1, utils.NewEvent(resp))
return true
}

View File

@ -3,9 +3,16 @@ package utils
import (
"fmt"
"github.com/goccy/go-json"
"math/rand"
"time"
)
func Intn(n int) int {
source := rand.NewSource(time.Now().UnixNano())
r := rand.New(source)
return r.Intn(n)
}
func Sum[T int | int64 | float32 | float64](arr []T) T {
var res T
for _, v := range arr {

View File

@ -3,7 +3,6 @@ package utils
import (
"fmt"
"github.com/goccy/go-json"
"math/rand"
"regexp"
"strconv"
"strings"
@ -11,13 +10,13 @@ import (
)
func GetRandomInt(min int, max int) int {
return rand.Intn(max-min) + min
return Intn(max-min) + min
}
func GenerateCode(length int) string {
var code string
for i := 0; i < length; i++ {
code += strconv.Itoa(rand.Intn(10))
code += strconv.Itoa(Intn(10))
}
return code
}
@ -26,7 +25,7 @@ func GenerateChar(length int) string {
const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
result := make([]byte, length)
for i := 0; i < length; i++ {
result[i] = charset[rand.Intn(len(charset))]
result[i] = charset[Intn(len(charset))]
}
return string(result)
}

View File

@ -1,12 +0,0 @@
package utils
import (
"math/rand"
"strings"
)
func GetRandomKey(apikey string) string {
arr := strings.Split(apikey, "|")
idx := rand.Intn(len(arr))
return arr[idx]
}