diff --git a/addition/generation/api.go b/addition/generation/api.go index 379ad07..e9797ca 100644 --- a/addition/generation/api.go +++ b/addition/generation/api.go @@ -61,8 +61,8 @@ func GenerateAPI(c *gin.Context) { return } - reversible := globals.IsGPT4NativeModel(form.Model) && auth.CanEnableSubscription(db, cache, user) - if !auth.CanEnableModelWithSubscription(db, user, form.Model, reversible) { + check, plan := auth.CanEnableModelWithSubscription(db, cache, user, form.Model) + if !check { conn.Send(globals.GenerationSegmentResponse{ Message: "You don't have enough quota to use this model.", Quota: 0, @@ -72,7 +72,7 @@ func GenerateAPI(c *gin.Context) { } var instance *utils.Buffer - hash, err := CreateGenerationWithCache(form.Model, form.Prompt, reversible, func(buffer *utils.Buffer, data string) { + hash, err := CreateGenerationWithCache(form.Model, form.Prompt, plan, func(buffer *utils.Buffer, data string) { instance = buffer conn.Send(globals.GenerationSegmentResponse{ End: false, @@ -81,7 +81,7 @@ func GenerateAPI(c *gin.Context) { }) }) - if instance != nil && !reversible && instance.GetQuota() > 0 && user != nil { + if instance != nil && !plan && instance.GetQuota() > 0 && user != nil { user.UseQuota(db, instance.GetQuota()) } diff --git a/addition/generation/prompt.go b/addition/generation/prompt.go index 731c7b4..e88b8f1 100644 --- a/addition/generation/prompt.go +++ b/addition/generation/prompt.go @@ -11,14 +11,14 @@ type ProjectResult struct { Result map[string]interface{} `json:"result"` } -func CreateGeneration(model string, prompt string, path string, reversible bool, hook func(buffer *utils.Buffer, data string)) error { +func CreateGeneration(model string, prompt string, path string, plan bool, hook func(buffer *utils.Buffer, data string)) error { message := GenerateMessage(prompt) buffer := utils.NewBuffer(model, message) if err := adapter.NewChatRequest(&adapter.ChatProps{ Model: model, Message: message, - Reversible: reversible && globals.IsGPT4Model(model), + Reversible: plan, Infinity: true, }, func(data string) error { buffer.Write(data) diff --git a/auth/package.go b/auth/package.go index 2eb1ef7..8e21ccb 100644 --- a/auth/package.go +++ b/auth/package.go @@ -7,6 +7,23 @@ type GiftResponse struct { Teenager bool `json:"teenager"` } +func (u *User) HasPackage(db *sql.DB, _t string) bool { + var count int + if err := db.QueryRow(`SELECT COUNT(*) FROM package where user_id = ? AND type = ?`, u.ID, _t).Scan(&count); err != nil { + return false + } + + return count > 0 +} + +func (u *User) HasCertPackage(db *sql.DB) bool { + return u.HasPackage(db, "cert") +} + +func (u *User) HasTeenagerPackage(db *sql.DB) bool { + return u.HasPackage(db, "teenager") +} + func NewPackage(db *sql.DB, user *User, _t string) bool { id := user.GetID(db) diff --git a/auth/payment.go b/auth/payment.go index 9a65ff1..3a0067c 100644 --- a/auth/payment.go +++ b/auth/payment.go @@ -1,7 +1,6 @@ package auth import ( - "chat/globals" "chat/utils" "database/sql" "encoding/json" @@ -72,35 +71,6 @@ func ReduceDalle(db *sql.DB, user *User) bool { return user.UseQuota(db, 1) } -func CanEnableModel(db *sql.DB, user *User, model string) bool { - switch model { - case globals.GPT3Turbo, globals.GPT3Turbo0301, globals.GPT3Turbo0613, - globals.Claude2: - return true - case globals.GPT4, globals.GPT40613, globals.GPT40314: - return user != nil && user.GetQuota(db) >= 5 - case globals.GPT432k, globals.GPT432k0613, globals.GPT432k0314: - return user != nil && user.GetQuota(db) >= 50 - case globals.SparkDesk: - return user != nil && user.GetQuota(db) >= 1 - case globals.Claude2100k: - return user != nil && user.GetQuota(db) >= 1 - case globals.ZhiPuChatGLMPro, globals.ZhiPuChatGLMStd: - return user != nil && user.GetQuota(db) >= 1 - default: - return user != nil - } -} - -func CanEnableModelWithSubscription(db *sql.DB, user *User, model string, useReverse bool) bool { - if utils.Contains(model, globals.GPT4Array) { - if useReverse { - return true - } - } - return CanEnableModel(db, user, model) -} - func BuyQuota(db *sql.DB, user *User, quota int) bool { money := float32(quota) * 0.1 if Pay(user.Username, money) { diff --git a/auth/rule.go b/auth/rule.go new file mode 100644 index 0000000..f9ae281 --- /dev/null +++ b/auth/rule.go @@ -0,0 +1,54 @@ +package auth + +import ( + "chat/globals" + "database/sql" + "github.com/go-redis/redis/v8" +) + +// CanEnableModel returns whether the model can be enabled (without subscription) +func CanEnableModel(db *sql.DB, user *User, model string) bool { + switch model { + case globals.GPT3Turbo, globals.GPT3Turbo0301, globals.GPT3Turbo0613, + globals.Claude2: + return true + case globals.GPT4, globals.GPT40613, globals.GPT40314: + return user != nil && user.GetQuota(db) >= 5 + case globals.GPT432k, globals.GPT432k0613, globals.GPT432k0314: + return user != nil && user.GetQuota(db) >= 50 + case globals.SparkDesk: + return user != nil && user.GetQuota(db) >= 1 + case globals.Claude2100k: + return user != nil && user.GetQuota(db) >= 1 + case globals.ZhiPuChatGLMPro, globals.ZhiPuChatGLMStd: + return user != nil && user.GetQuota(db) >= 1 + default: + return user != nil + } +} + +func HandleSubscriptionUsage(db *sql.DB, cache *redis.Client, user *User, model string) bool { + subscription := user.IsSubscribe(db) + if globals.IsGPT4NativeModel(model) { + return subscription && IncreaseSubscriptionUsage(cache, user, globals.GPT4, 50) + } else if model == globals.Claude2100k { + if subscription || user.HasTeenagerPackage(db) { + // free for subscription users and students + return true + } else { + // 30 100k quota for common users + return IncreaseSubscriptionUsage(cache, user, globals.Claude2100k, 30) + } + } + + return false +} + +// CanEnableModelWithSubscription returns (canEnable, usePlan) +func CanEnableModelWithSubscription(db *sql.DB, cache *redis.Client, user *User, model string) (bool, bool) { + // use subscription quota first + if user != nil && HandleSubscriptionUsage(db, cache, user, model) { + return true, true + } + return CanEnableModel(db, user, model), false +} diff --git a/auth/subscription.go b/auth/subscription.go index 653d6c7..9772eac 100644 --- a/auth/subscription.go +++ b/auth/subscription.go @@ -31,19 +31,12 @@ func BuySubscription(db *sql.DB, user *User, month int) bool { return false } -func IncreaseSubscriptionUsage(cache *redis.Client, user *User) bool { - return utils.IncrWithLimit(cache, globals.GetGPT4LimitFormat(user.ID), 1, 50, 60*60*24) // 1 day +func IncreaseSubscriptionUsage(cache *redis.Client, user *User, t string, limit int64) bool { + return utils.IncrWithLimit(cache, globals.GetSubscriptionLimitFormat(t, user.ID), 1, limit, 60*60*24) // 1 day } -func DecreaseSubscriptionUsage(cache *redis.Client, user *User) bool { - return utils.DecrInt(cache, globals.GetGPT4LimitFormat(user.ID), 1) -} - -func CanEnableSubscription(db *sql.DB, cache *redis.Client, user *User) bool { - if user == nil { - return false - } - return user.IsSubscribe(db) && IncreaseSubscriptionUsage(cache, user) +func DecreaseSubscriptionUsage(cache *redis.Client, user *User, t string) bool { + return utils.DecrInt(cache, globals.GetSubscriptionLimitFormat(t, user.ID), 1) } func GetDalleUsageLimit(db *sql.DB, user *User) int { diff --git a/auth/user.go b/auth/user.go index 0b00723..e5dbd2c 100644 --- a/auth/user.go +++ b/auth/user.go @@ -170,7 +170,7 @@ type Usage struct { } func (u *User) GetSubscriptionUsage(db *sql.DB, cache *redis.Client) Usage { - gpt4, _ := utils.GetInt(cache, globals.GetGPT4LimitFormat(u.GetID(db))) + gpt4, _ := utils.GetInt(cache, globals.GetSubscriptionLimitFormat(globals.GPT4, u.GetID(db))) dalle, _ := utils.GetInt(cache, globals.GetImageLimitFormat(u.GetID(db))) return Usage{ diff --git a/globals/usage.go b/globals/usage.go index 20618c7..519325b 100644 --- a/globals/usage.go +++ b/globals/usage.go @@ -9,6 +9,6 @@ func GetImageLimitFormat(id int64) string { return fmt.Sprintf(":imagelimit:%s:%d", time.Now().Format("2006-01-02"), id) } -func GetGPT4LimitFormat(id int64) string { - return fmt.Sprintf(":subscription-usage:%s:%d", time.Now().Format("2006-01-02"), id) +func GetSubscriptionLimitFormat(t string, id int64) string { + return fmt.Sprintf(":subscription-usage-%s:%s:%d", t, time.Now().Format("2006-01-02"), id) } diff --git a/manager/chat.go b/manager/chat.go index a3c63b8..76f4f47 100644 --- a/manager/chat.go +++ b/manager/chat.go @@ -90,9 +90,9 @@ func ChatHandler(conn *Connection, user *auth.User, instance *conversation.Conve model := instance.GetModel() db := conn.GetDB() cache := conn.GetCache() - reversible := globals.IsGPT4NativeModel(model) && auth.CanEnableSubscription(db, cache, user) + check, plan := auth.CanEnableModelWithSubscription(db, cache, user, model) - if !auth.CanEnableModelWithSubscription(db, user, model, reversible) { + if !check { conn.Send(globals.ChatSegmentResponse{ Message: defaultQuotaMessage, Quota: 0, @@ -104,7 +104,7 @@ func ChatHandler(conn *Connection, user *auth.User, instance *conversation.Conve if form := ExtractCacheData(conn.GetCtx(), &CacheProps{ Message: segment, Model: model, - Reversible: reversible, + Reversible: plan, }); form != nil { MockStreamSender(conn, form.Message) return form.Message @@ -114,7 +114,7 @@ func ChatHandler(conn *Connection, user *auth.User, instance *conversation.Conve err := adapter.NewChatRequest(&adapter.ChatProps{ Model: model, Message: segment, - Reversible: reversible && globals.IsGPT4Model(model), + Reversible: plan, }, func(data string) error { if signal := conn.PeekWithType(StopType); signal != nil { // stop signal from client @@ -130,7 +130,7 @@ func ChatHandler(conn *Connection, user *auth.User, instance *conversation.Conve if err != nil && err.Error() != "signal" { globals.Warn(fmt.Sprintf("caught error from chat handler: %s (instance: %s, client: %s)", err, model, conn.GetCtx().ClientIP())) - CollectQuota(conn.GetCtx(), user, buffer.GetQuota(), reversible) + CollectQuota(conn.GetCtx(), user, buffer.GetQuota(), plan) conn.Send(globals.ChatSegmentResponse{ Message: err.Error(), Quota: GetErrorQuota(model), @@ -139,7 +139,7 @@ func ChatHandler(conn *Connection, user *auth.User, instance *conversation.Conve return err.Error() } - CollectQuota(conn.GetCtx(), user, buffer.GetQuota(), reversible) + CollectQuota(conn.GetCtx(), user, buffer.GetQuota(), plan) if buffer.IsEmpty() { conn.Send(globals.ChatSegmentResponse{ @@ -158,7 +158,7 @@ func ChatHandler(conn *Connection, user *auth.User, instance *conversation.Conve SaveCacheData(conn.GetCtx(), &CacheProps{ Message: segment, Model: model, - Reversible: reversible, + Reversible: plan, }, &CacheData{ Keyword: keyword, Message: result, diff --git a/manager/completions.go b/manager/completions.go index 50820be..e5f10dd 100644 --- a/manager/completions.go +++ b/manager/completions.go @@ -23,16 +23,16 @@ func NativeChatHandler(c *gin.Context, user *auth.User, model string, message [] db := utils.GetDBFromContext(c) cache := utils.GetCacheFromContext(c) - reversible := globals.IsGPT4NativeModel(model) && auth.CanEnableSubscription(db, cache, user) + check, plan := auth.CanEnableModelWithSubscription(db, cache, user, model) - if !auth.CanEnableModelWithSubscription(db, user, model, reversible) { + if !check { return keyword, defaultQuotaMessage, 0 } if form := ExtractCacheData(c, &CacheProps{ Message: segment, Model: model, - Reversible: reversible, + Reversible: plan, }); form != nil { return form.Keyword, form.Message, 0 } @@ -40,22 +40,22 @@ func NativeChatHandler(c *gin.Context, user *auth.User, model string, message [] buffer := utils.NewBuffer(model, segment) if err := adapter.NewChatRequest(&adapter.ChatProps{ Model: model, - Reversible: reversible && globals.IsGPT4Model(model), + Reversible: plan, Message: segment, }, func(resp string) error { buffer.Write(resp) return nil }); err != nil { - CollectQuota(c, user, buffer.GetQuota(), reversible) + CollectQuota(c, user, buffer.GetQuota(), plan) return keyword, err.Error(), GetErrorQuota(model) } - CollectQuota(c, user, buffer.GetQuota(), reversible) + CollectQuota(c, user, buffer.GetQuota(), plan) SaveCacheData(c, &CacheProps{ Message: segment, Model: model, - Reversible: reversible, + Reversible: plan, }, &CacheData{ Keyword: keyword, Message: buffer.ReadWithDefault(defaultMessage), diff --git a/manager/transhipment.go b/manager/transhipment.go index 0ed1e50..26cba57 100644 --- a/manager/transhipment.go +++ b/manager/transhipment.go @@ -93,9 +93,8 @@ func TranshipmentAPI(c *gin.Context) { id := utils.Md5Encrypt(username + form.Model + time.Now().String()) created := time.Now().Unix() - reversible := globals.IsGPT4NativeModel(form.Model) && auth.CanEnableSubscription(db, cache, user) - - if !auth.CanEnableModelWithSubscription(db, user, form.Model, reversible) { + check, plan := auth.CanEnableModelWithSubscription(db, cache, user, form.Model) + if !check { c.JSON(http.StatusForbidden, gin.H{ "status": false, "error": "quota exceeded", @@ -105,18 +104,18 @@ func TranshipmentAPI(c *gin.Context) { } if form.Stream { - sendStreamTranshipmentResponse(c, form, id, created, user, reversible) + sendStreamTranshipmentResponse(c, form, id, created, user, plan) } else { - sendTranshipmentResponse(c, form, id, created, user, reversible) + sendTranshipmentResponse(c, form, id, created, user, plan) } } -func sendTranshipmentResponse(c *gin.Context, form TranshipmentForm, id string, created int64, user *auth.User, reversible bool) { +func sendTranshipmentResponse(c *gin.Context, form TranshipmentForm, id string, created int64, user *auth.User, plan bool) { buffer := utils.NewBuffer(form.Model, form.Messages) err := adapter.NewChatRequest(&adapter.ChatProps{ Model: form.Model, Message: form.Messages, - Reversible: reversible && globals.IsGPT4Model(form.Model), + Reversible: plan, Token: form.MaxTokens, }, func(data string) error { buffer.Write(data) @@ -126,7 +125,7 @@ func sendTranshipmentResponse(c *gin.Context, form TranshipmentForm, id string, globals.Warn(fmt.Sprintf("error from chat request api: %s", err.Error())) } - CollectQuota(c, user, buffer.GetQuota(), reversible) + CollectQuota(c, user, buffer.GetQuota(), plan) c.JSON(http.StatusOK, TranshipmentResponse{ Id: id, Object: "chat.completion", @@ -172,7 +171,7 @@ func getStreamTranshipmentForm(id string, created int64, form TranshipmentForm, } } -func sendStreamTranshipmentResponse(c *gin.Context, form TranshipmentForm, id string, created int64, user *auth.User, reversible bool) { +func sendStreamTranshipmentResponse(c *gin.Context, form TranshipmentForm, id string, created int64, user *auth.User, plan bool) { channel := make(chan TranshipmentStreamResponse) go func() { @@ -180,20 +179,20 @@ func sendStreamTranshipmentResponse(c *gin.Context, form TranshipmentForm, id st if err := adapter.NewChatRequest(&adapter.ChatProps{ Model: form.Model, Message: form.Messages, - Reversible: reversible && globals.IsGPT4Model(form.Model), + Reversible: plan, Token: form.MaxTokens, }, func(data string) error { channel <- getStreamTranshipmentForm(id, created, form, buffer.Write(data), buffer, false) return nil }); err != nil { channel <- getStreamTranshipmentForm(id, created, form, fmt.Sprintf("Error: %s", err.Error()), buffer, true) - CollectQuota(c, user, buffer.GetQuota(), reversible) + CollectQuota(c, user, buffer.GetQuota(), plan) close(channel) return } channel <- getStreamTranshipmentForm(id, created, form, "", buffer, true) - CollectQuota(c, user, buffer.GetQuota(), reversible) + CollectQuota(c, user, buffer.GetQuota(), plan) close(channel) return }() diff --git a/screenshot/landspace.png b/screenshot/landspace.png index 04b10e0..cbee28f 100644 Binary files a/screenshot/landspace.png and b/screenshot/landspace.png differ