diff --git a/addition/generation/prompt.go b/addition/generation/prompt.go index e469682..95fd471 100644 --- a/addition/generation/prompt.go +++ b/addition/generation/prompt.go @@ -2,6 +2,7 @@ package generation import ( "chat/adapter" + "chat/admin" "chat/globals" "chat/utils" "fmt" @@ -15,7 +16,7 @@ func CreateGeneration(model string, prompt string, path string, plan bool, hook message := GenerateMessage(prompt) buffer := utils.NewBuffer(model, message) - if err := adapter.NewChatRequest(&adapter.ChatProps{ + err := adapter.NewChatRequest(&adapter.ChatProps{ Model: model, Message: message, Plan: plan, @@ -24,7 +25,10 @@ func CreateGeneration(model string, prompt string, path string, plan bool, hook buffer.Write(data) hook(buffer, data) return nil - }); err != nil { + }) + + admin.AnalysisRequest(model, buffer, err) + if err != nil { return err } diff --git a/admin/analysis.go b/admin/analysis.go new file mode 100644 index 0000000..b79a9d7 --- /dev/null +++ b/admin/analysis.go @@ -0,0 +1,93 @@ +package admin + +import ( + "chat/globals" + "chat/utils" + "database/sql" + "github.com/go-redis/redis/v8" + "time" +) + +func getDates(t []time.Time) []string { + return utils.Each[time.Time, string](t, func(date time.Time) string { + return date.Format("1/2") + }) +} + +func getFormat(t time.Time) string { + return t.Format("2006-01-02") +} + +func GetSubscriptionUsers(db *sql.DB) int64 { + var count int64 + err := db.QueryRow(` + SELECT COUNT(*) FROM subscription WHERE expired_at > NOW() + `).Scan(&count) + if err != nil { + return 0 + } + + return count +} + +func GetBillingToday(cache *redis.Client) int64 { + return utils.MustInt(cache, getBillingFormat(getDay())) +} + +func GetBillingMonth(cache *redis.Client) int64 { + return utils.MustInt(cache, getMonthBillingFormat(getMonth())) +} + +func GetModelData(cache *redis.Client) ModelChartForm { + dates := getDays(7) + + return ModelChartForm{ + Date: getDates(dates), + Value: utils.EachNotNil[string, ModelData](globals.AllModels, func(model string) *ModelData { + data := ModelData{ + Model: model, + Data: utils.Each[time.Time, int64](dates, func(date time.Time) int64 { + return utils.MustInt(cache, getModelFormat(getFormat(date), model)) + }), + } + if utils.Sum(data.Data) == 0 { + return nil + } + + return &data + }), + } +} + +func GetRequestData(cache *redis.Client) RequestChartForm { + dates := getDays(7) + + return RequestChartForm{ + Date: getDates(dates), + Value: utils.Each[time.Time, int64](dates, func(date time.Time) int64 { + return utils.MustInt(cache, getRequestFormat(getFormat(date))) + }), + } +} + +func GetBillingData(cache *redis.Client) BillingChartForm { + dates := getDays(30) + + return BillingChartForm{ + Date: getDates(dates), + Value: utils.Each[time.Time, float32](dates, func(date time.Time) float32 { + return float32(utils.MustInt(cache, getBillingFormat(getFormat(date)))) / 100. + }), + } +} + +func GetErrorData(cache *redis.Client) ErrorChartForm { + dates := getDays(7) + + return ErrorChartForm{ + Date: getDates(dates), + Value: utils.Each[time.Time, int64](dates, func(date time.Time) int64 { + return utils.MustInt(cache, getErrorFormat(getFormat(date))) + }), + } +} diff --git a/admin/controller.go b/admin/controller.go new file mode 100644 index 0000000..6c18c1d --- /dev/null +++ b/admin/controller.go @@ -0,0 +1,38 @@ +package admin + +import ( + "chat/utils" + "github.com/gin-gonic/gin" + "net/http" +) + +func InfoAPI(c *gin.Context) { + db := utils.GetDBFromContext(c) + cache := utils.GetCacheFromContext(c) + + c.JSON(http.StatusOK, InfoForm{ + SubscriptionCount: GetSubscriptionUsers(db), + BillingToday: GetBillingToday(cache), + BillingMonth: GetBillingMonth(cache), + }) +} + +func ModelAnalysisAPI(c *gin.Context) { + cache := utils.GetCacheFromContext(c) + c.JSON(http.StatusOK, GetModelData(cache)) +} + +func RequestAnalysisAPI(c *gin.Context) { + cache := utils.GetCacheFromContext(c) + c.JSON(http.StatusOK, GetRequestData(cache)) +} + +func BillingAnalysisAPI(c *gin.Context) { + cache := utils.GetCacheFromContext(c) + c.JSON(http.StatusOK, GetBillingData(cache)) +} + +func ErrorAnalysisAPI(c *gin.Context) { + cache := utils.GetCacheFromContext(c) + c.JSON(http.StatusOK, GetErrorData(cache)) +} diff --git a/admin/format.go b/admin/format.go new file mode 100644 index 0000000..daa04a8 --- /dev/null +++ b/admin/format.go @@ -0,0 +1,46 @@ +package admin + +import ( + "fmt" + "time" +) + +func getMonth() string { + date := time.Now() + return date.Format("2006-01") +} + +func getDay() string { + date := time.Now() + return date.Format("2006-01-02") +} + +func getDays(n int) []time.Time { + current := time.Now() + var days []time.Time + for i := n; i > 0; i-- { + days = append(days, current.AddDate(0, 0, -i+1)) + } + + return days +} + +func getErrorFormat(t string) string { + return fmt.Sprintf("nio:err-analysis-%s", t) +} + +func getBillingFormat(t string) string { + return fmt.Sprintf("nio:billing-analysis-%s", t) +} + +func getMonthBillingFormat(t string) string { + return fmt.Sprintf("nio:billing-analysis-%s", t) +} + +func getRequestFormat(t string) string { + return fmt.Sprintf("nio:request-analysis-%s", t) +} + +func getModelFormat(t string, model string) string { + return fmt.Sprintf("nio:model-analysis-%s-%s", model, t) +} diff --git a/admin/router.go b/admin/router.go new file mode 100644 index 0000000..01078ad --- /dev/null +++ b/admin/router.go @@ -0,0 +1,11 @@ +package admin + +import "github.com/gin-gonic/gin" + +func Register(app *gin.Engine) { + app.GET("/admin/analytics/info", InfoAPI) + app.GET("/admin/analytics/model", ModelAnalysisAPI) + app.GET("/admin/analytics/request", RequestAnalysisAPI) + app.GET("/admin/analytics/billing", BillingAnalysisAPI) + app.GET("/admin/analytics/error", ErrorAnalysisAPI) +} diff --git a/admin/statistic.go b/admin/statistic.go new file mode 100644 index 0000000..7794e24 --- /dev/null +++ b/admin/statistic.go @@ -0,0 +1,37 @@ +package admin + +import ( + "chat/connection" + "chat/utils" + "github.com/go-redis/redis/v8" + "time" +) + +func IncrErrorRequest(cache *redis.Client) { + utils.IncrOnce(cache, getErrorFormat(getDay()), time.Hour*24*7*2) +} + +func IncrBillingRequest(cache *redis.Client, amount int64) { + utils.IncrWithExpire(cache, getBillingFormat(getDay()), amount, time.Hour*24*30*2) + utils.IncrWithExpire(cache, getMonthBillingFormat(getMonth()), amount, time.Hour*24*30*2) +} + +func IncrRequest(cache *redis.Client) { + utils.IncrOnce(cache, getRequestFormat(getDay()), time.Hour*24*7*2) +} + +func IncrModelRequest(cache *redis.Client, model string, tokens int64) { + IncrRequest(cache) + utils.IncrWithExpire(cache, getModelFormat(getDay(), model), tokens, time.Hour*24*7*2) +} + +func AnalysisRequest(model string, buffer *utils.Buffer, err error) { + instance := connection.Cache + + if err != nil && err.Error() != "signal" { + IncrErrorRequest(instance) + return + } + + IncrModelRequest(instance, model, int64(buffer.CountToken())) +} diff --git a/admin/types.go b/admin/types.go new file mode 100644 index 0000000..80ae9a8 --- /dev/null +++ b/admin/types.go @@ -0,0 +1,32 @@ +package admin + +type InfoForm struct { + BillingToday int64 `json:"billing_today"` + BillingMonth int64 `json:"billing_month"` + SubscriptionCount int64 `json:"subscription_count"` +} + +type ModelData struct { + Model string `json:"model"` + Data []int64 `json:"data"` +} + +type ModelChartForm struct { + Date []string `json:"date"` + Value []ModelData `json:"value"` +} + +type RequestChartForm struct { + Date []string `json:"date"` + Value []int64 `json:"value"` +} + +type BillingChartForm struct { + Date []string `json:"date"` + Value []float32 `json:"value"` +} + +type ErrorChartForm struct { + Date []string `json:"date"` + Value []int64 `json:"value"` +} diff --git a/app/src/admin/api.ts b/app/src/admin/api.ts index a1817ce..58627b0 100644 --- a/app/src/admin/api.ts +++ b/app/src/admin/api.ts @@ -22,7 +22,12 @@ export async function getModelChart(): Promise { return { date: [], value: [] }; } - return response.data as ModelChartResponse; + const data = response.data as ModelChartResponse; + + return { + date: data.date, + value: data.value || [], + } } export async function getRequestChart(): Promise { diff --git a/app/src/conf.ts b/app/src/conf.ts index f831127..6f9054e 100644 --- a/app/src/conf.ts +++ b/app/src/conf.ts @@ -8,7 +8,7 @@ import { } from "@/utils/env.ts"; import { getMemory } from "@/utils/memory.ts"; -export const version = "3.6.13rc1"; +export const version = "3.6.14"; export const dev: boolean = getDev(); export const deploy: boolean = true; export let rest_api: string = getRestApi(deploy); diff --git a/app/src/routes/Admin.tsx b/app/src/routes/Admin.tsx index 7a85813..24bc237 100644 --- a/app/src/routes/Admin.tsx +++ b/app/src/routes/Admin.tsx @@ -12,7 +12,7 @@ function Admin() { useEffect(() => { if (init && !admin) router.navigate("/"); - }, []); + }, [init]); return (
diff --git a/auth/controller.go b/auth/controller.go index 73be752..a36384d 100644 --- a/auth/controller.go +++ b/auth/controller.go @@ -39,6 +39,25 @@ func RequireAuth(c *gin.Context) *User { return user } +func RequireAdmin(c *gin.Context) *User { + user := RequireAuth(c) + if user == nil { + return nil + } + + db := utils.GetDBFromContext(c) + if !user.IsAdmin(db) { + c.JSON(200, gin.H{ + "status": false, + "error": "admin required", + }) + c.Abort() + return nil + } + + return user +} + func RequireSubscribe(c *gin.Context) *User { user := RequireAuth(c) if user == nil { @@ -127,6 +146,7 @@ func SubscribeAPI(c *gin.Context) { } db := utils.GetDBFromContext(c) + cache := utils.GetCacheFromContext(c) var form SubscribeForm if err := c.ShouldBindJSON(&form); err != nil { c.JSON(200, gin.H{ @@ -144,7 +164,7 @@ func SubscribeAPI(c *gin.Context) { return } - if BuySubscription(db, user, form.Month) { + if BuySubscription(db, cache, user, form.Month) { c.JSON(200, gin.H{ "status": true, "error": "success", @@ -164,6 +184,7 @@ func BuyAPI(c *gin.Context) { } db := utils.GetDBFromContext(c) + cache := utils.GetCacheFromContext(c) var form BuyForm if err := c.ShouldBindJSON(&form); err != nil { c.JSON(200, gin.H{ @@ -181,7 +202,7 @@ func BuyAPI(c *gin.Context) { return } - if BuyQuota(db, user, form.Quota) { + if BuyQuota(db, cache, user, form.Quota) { c.JSON(200, gin.H{ "status": true, "error": "success", diff --git a/auth/payment.go b/auth/payment.go index 6738c5c..ec99b9f 100644 --- a/auth/payment.go +++ b/auth/payment.go @@ -1,9 +1,11 @@ package auth import ( + "chat/admin" "chat/utils" "database/sql" "encoding/json" + "github.com/go-redis/redis/v8" "github.com/spf13/viper" ) @@ -64,9 +66,17 @@ func Pay(username string, amount float32) bool { return resp.Type } -func BuyQuota(db *sql.DB, user *User, quota int) bool { +func (u *User) Pay(cache *redis.Client, amount float32) bool { + state := Pay(u.Username, amount) + if state { + admin.IncrBillingRequest(cache, int64(amount*100)) + } + return state +} + +func BuyQuota(db *sql.DB, cache *redis.Client, user *User, quota int) bool { money := float32(quota) * 0.1 - if Pay(user.Username, money) { + if user.Pay(cache, money) { user.IncreaseQuota(db, float32(quota)) return true } diff --git a/auth/subscription.go b/auth/subscription.go index 6f28eef..d12afde 100644 --- a/auth/subscription.go +++ b/auth/subscription.go @@ -19,12 +19,12 @@ func CountSubscriptionPrize(month int) float32 { return base } -func BuySubscription(db *sql.DB, user *User, month int) bool { +func BuySubscription(db *sql.DB, cache *redis.Client, user *User, month int) bool { if month < 1 || month > 999 { return false } money := CountSubscriptionPrize(month) - if Pay(user.Username, money) { + if user.Pay(cache, money) { user.AddSubscription(db, month) return true } diff --git a/auth/user.go b/auth/user.go index 2235d5c..a087fda 100644 --- a/auth/user.go +++ b/auth/user.go @@ -21,6 +21,7 @@ type User struct { BindID int64 `json:"bind_id"` Password string `json:"password"` Token string `json:"token"` + Admin bool `json:"is_admin"` Subscription *time.Time `json:"subscription"` } @@ -82,6 +83,20 @@ func (u *User) GenerateToken() (string, error) { return token, nil } +func (u *User) IsAdmin(db *sql.DB) bool { + if u.Admin { + return true + } + + var admin sql.NullBool + if err := db.QueryRow("SELECT is_admin FROM auth WHERE username = ?", u.Username).Scan(&admin); err != nil { + return false + } + + u.Admin = admin.Valid && admin.Bool + return u.Admin +} + func (u *User) GetID(db *sql.DB) int64 { if u.ID > 0 { return u.ID @@ -345,6 +360,7 @@ func StateAPI(c *gin.Context) { c.JSON(http.StatusOK, gin.H{ "status": len(username) != 0, "user": username, + "admin": utils.GetAdminFromContext(c), }) } diff --git a/connection/database.go b/connection/database.go index 2106296..1a6516e 100644 --- a/connection/database.go +++ b/connection/database.go @@ -56,7 +56,8 @@ func CreateUserTable(db *sql.DB) { bind_id INT UNIQUE, username VARCHAR(24) UNIQUE, token VARCHAR(255) NOT NULL, - password VARCHAR(64) NOT NULL + password VARCHAR(64) NOT NULL, + is_admin BOOLEAN DEFAULT FALSE ); `) if err != nil { diff --git a/main.go b/main.go index caaca27..b13a258 100644 --- a/main.go +++ b/main.go @@ -2,6 +2,7 @@ package main import ( "chat/addition" + "chat/admin" "chat/auth" "chat/cli" "chat/manager" @@ -27,6 +28,7 @@ func main() { { auth.Register(app) + admin.Register(app) manager.Register(app) addition.Register(app) conversation.Register(app) diff --git a/manager/chat.go b/manager/chat.go index 855cb94..6856535 100644 --- a/manager/chat.go +++ b/manager/chat.go @@ -3,6 +3,7 @@ package manager import ( "chat/adapter" "chat/addition/web" + "chat/admin" "chat/auth" "chat/globals" "chat/manager/conversation" @@ -103,6 +104,7 @@ func ChatHandler(conn *Connection, user *auth.User, instance *conversation.Conve }) }) + admin.AnalysisRequest(model, buffer, err) if err != nil && err.Error() != "signal" { globals.Warn(fmt.Sprintf("caught error from chat handler: %s (instance: %s, client: %s)", err, model, conn.GetCtx().ClientIP())) diff --git a/manager/completions.go b/manager/completions.go index 3eb1480..653ad46 100644 --- a/manager/completions.go +++ b/manager/completions.go @@ -3,6 +3,7 @@ package manager import ( "chat/adapter" "chat/addition/web" + "chat/admin" "chat/auth" "chat/globals" "chat/utils" @@ -38,14 +39,17 @@ func NativeChatHandler(c *gin.Context, user *auth.User, model string, message [] } buffer := utils.NewBuffer(model, segment) - if err := adapter.NewChatRequest(&adapter.ChatProps{ + err := adapter.NewChatRequest(&adapter.ChatProps{ Model: model, Plan: plan, Message: segment, }, func(resp string) error { buffer.Write(resp) return nil - }); err != nil { + }) + + admin.AnalysisRequest(model, buffer, err) + if err != nil { auth.RevertSubscriptionUsage(cache, user, model, plan) CollectQuota(c, user, buffer, plan) return keyword, err.Error(), GetErrorQuota(model) diff --git a/manager/transhipment.go b/manager/transhipment.go index 70c892f..1342328 100644 --- a/manager/transhipment.go +++ b/manager/transhipment.go @@ -2,6 +2,7 @@ package manager import ( "chat/adapter" + "chat/admin" "chat/auth" "chat/globals" "chat/utils" @@ -125,6 +126,8 @@ func sendTranshipmentResponse(c *gin.Context, form TranshipmentForm, id string, buffer.Write(data) return nil }) + + admin.AnalysisRequest(form.Model, buffer, err) if err != nil { globals.Warn(fmt.Sprintf("error from chat request api: %s", err.Error())) } @@ -180,7 +183,7 @@ func sendStreamTranshipmentResponse(c *gin.Context, form TranshipmentForm, id st go func() { buffer := utils.NewBuffer(form.Model, form.Messages) - if err := adapter.NewChatRequest(&adapter.ChatProps{ + err := adapter.NewChatRequest(&adapter.ChatProps{ Model: form.Model, Message: form.Messages, Plan: plan, @@ -188,7 +191,10 @@ func sendStreamTranshipmentResponse(c *gin.Context, form TranshipmentForm, id st }, func(data string) error { channel <- getStreamTranshipmentForm(id, created, form, buffer.Write(data), buffer, false) return nil - }); err != nil { + }) + + admin.AnalysisRequest(form.Model, buffer, err) + if err != nil { channel <- getStreamTranshipmentForm(id, created, form, fmt.Sprintf("Error: %s", err.Error()), buffer, true) CollectQuota(c, user, buffer, plan) close(channel) diff --git a/middleware/auth.go b/middleware/auth.go index 0677152..8c7f846 100644 --- a/middleware/auth.go +++ b/middleware/auth.go @@ -8,18 +8,21 @@ import ( "strings" ) -func ProcessToken(c *gin.Context, token string) bool { +func ProcessToken(c *gin.Context, token string) *auth.User { if user := auth.ParseToken(c, token); user != nil { c.Set("auth", true) c.Set("user", user.Username) c.Set("agent", "token") - c.Next() - return true + return user } - return false + + c.Set("auth", false) + c.Set("user", "") + c.Set("agent", "") + return nil } -func ProcessKey(c *gin.Context, key string) bool { +func ProcessKey(c *gin.Context, key string) *auth.User { addr := c.ClientIP() cache := utils.GetCacheFromContext(c) @@ -28,15 +31,14 @@ func ProcessKey(c *gin.Context, key string) bool { "code": 403, "message": "ip in black list", }) - return false + return nil } if user := auth.ParseApiKey(c, key); user != nil { c.Set("auth", true) c.Set("user", user.Username) c.Set("agent", "api") - c.Next() - return true + return user } utils.IncrIP(cache, addr) @@ -44,31 +46,50 @@ func ProcessKey(c *gin.Context, key string) bool { "code": 401, "message": "Access denied. Please provide correct api key.", }) - return false + return nil +} + +func ProcessAuthorization(c *gin.Context) *auth.User { + k := strings.TrimSpace(c.GetHeader("Authorization")) + if k != "" { + if strings.HasPrefix(k, "Bearer ") { + k = strings.TrimPrefix(k, "Bearer ") + } + + if strings.HasPrefix(k, "sk-") { + // api agent + return ProcessKey(c, k) + } else { + // token agent + return ProcessToken(c, k) + } + } + + c.Set("auth", false) + c.Set("user", "") + c.Set("agent", "") + return nil } func AuthMiddleware() gin.HandlerFunc { return func(c *gin.Context) { - k := strings.TrimSpace(c.GetHeader("Authorization")) - if k != "" { - if strings.HasPrefix(k, "Bearer ") { - k = strings.TrimPrefix(k, "Bearer ") - } + path := c.Request.URL.Path + instance := ProcessAuthorization(c) - if strings.HasPrefix(k, "sk-") { // api agent - if ProcessKey(c, k) { - return - } - } else { // token agent - if ProcessToken(c, k) { - return - } + db := utils.GetDBFromContext(c) + + admin := instance != nil && instance.IsAdmin(db) + c.Set("admin", admin) + if strings.HasPrefix(path, "/admin") { + if !admin { + c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{ + "code": 401, + "message": "Access denied.", + }) + return } } - c.Set("auth", false) - c.Set("user", "") - c.Set("agent", "") c.Next() } } diff --git a/utils/base.go b/utils/base.go index 66bb452..12c8f42 100644 --- a/utils/base.go +++ b/utils/base.go @@ -5,6 +5,14 @@ import ( "fmt" ) +func Sum[T int | int64 | float32 | float64](arr []T) T { + var res T + for _, v := range arr { + res += v + } + return res +} + func Contains[T comparable](value T, slice []T) bool { for _, item := range slice { if item == value { @@ -123,3 +131,21 @@ func InsertChannel[T any](ch chan T, value T, index int) { ch <- v } } + +func Each[T any, U any](arr []T, f func(T) U) []U { + var res []U + for _, v := range arr { + res = append(res, f(v)) + } + return res +} + +func EachNotNil[T any, U any](arr []T, f func(T) *U) []U { + var res []U + for _, v := range arr { + if val := f(v); val != nil { + res = append(res, *val) + } + } + return res +} diff --git a/utils/cache.go b/utils/cache.go index 08d8879..018a873 100644 --- a/utils/cache.go +++ b/utils/cache.go @@ -65,6 +65,17 @@ func IncrIP(cache *redis.Client, ip string) int64 { return val } +func IncrWithExpire(cache *redis.Client, key string, delta int64, expiration time.Duration) { + _, err := Incr(cache, key, delta) + if err != nil && err == redis.Nil { + cache.Set(context.Background(), key, delta, expiration) + } +} + +func IncrOnce(cache *redis.Client, key string, expiration time.Duration) { + IncrWithExpire(cache, key, 1, expiration) +} + func IsInBlackList(cache *redis.Client, ip string) bool { val, err := GetInt(cache, fmt.Sprintf(":ip-rate:%s", ip)) return err == nil && val > 50 diff --git a/utils/ctx.go b/utils/ctx.go index 9507485..31627e4 100644 --- a/utils/ctx.go +++ b/utils/ctx.go @@ -18,6 +18,10 @@ func GetUserFromContext(c *gin.Context) string { return c.MustGet("user").(string) } +func GetAdminFromContext(c *gin.Context) bool { + return c.MustGet("admin").(bool) +} + func GetAgentFromContext(c *gin.Context) string { return c.MustGet("agent").(string) }