mirror of
https://github.com/coaidev/coai.git
synced 2025-05-19 21:10:18 +09:00
add admin dashboard
This commit is contained in:
parent
b8076e12a9
commit
85094836a6
@ -2,6 +2,7 @@ package generation
|
||||
|
||||
import (
|
||||
"chat/adapter"
|
||||
"chat/admin"
|
||||
"chat/globals"
|
||||
"chat/utils"
|
||||
"fmt"
|
||||
@ -15,7 +16,7 @@ func CreateGeneration(model string, prompt string, path string, plan bool, hook
|
||||
message := GenerateMessage(prompt)
|
||||
buffer := utils.NewBuffer(model, message)
|
||||
|
||||
if err := adapter.NewChatRequest(&adapter.ChatProps{
|
||||
err := adapter.NewChatRequest(&adapter.ChatProps{
|
||||
Model: model,
|
||||
Message: message,
|
||||
Plan: plan,
|
||||
@ -24,7 +25,10 @@ func CreateGeneration(model string, prompt string, path string, plan bool, hook
|
||||
buffer.Write(data)
|
||||
hook(buffer, data)
|
||||
return nil
|
||||
}); err != nil {
|
||||
})
|
||||
|
||||
admin.AnalysisRequest(model, buffer, err)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
|
93
admin/analysis.go
Normal file
93
admin/analysis.go
Normal file
@ -0,0 +1,93 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"chat/globals"
|
||||
"chat/utils"
|
||||
"database/sql"
|
||||
"github.com/go-redis/redis/v8"
|
||||
"time"
|
||||
)
|
||||
|
||||
func getDates(t []time.Time) []string {
|
||||
return utils.Each[time.Time, string](t, func(date time.Time) string {
|
||||
return date.Format("1/2")
|
||||
})
|
||||
}
|
||||
|
||||
func getFormat(t time.Time) string {
|
||||
return t.Format("2006-01-02")
|
||||
}
|
||||
|
||||
func GetSubscriptionUsers(db *sql.DB) int64 {
|
||||
var count int64
|
||||
err := db.QueryRow(`
|
||||
SELECT COUNT(*) FROM subscription WHERE expired_at > NOW()
|
||||
`).Scan(&count)
|
||||
if err != nil {
|
||||
return 0
|
||||
}
|
||||
|
||||
return count
|
||||
}
|
||||
|
||||
func GetBillingToday(cache *redis.Client) int64 {
|
||||
return utils.MustInt(cache, getBillingFormat(getDay()))
|
||||
}
|
||||
|
||||
func GetBillingMonth(cache *redis.Client) int64 {
|
||||
return utils.MustInt(cache, getMonthBillingFormat(getMonth()))
|
||||
}
|
||||
|
||||
func GetModelData(cache *redis.Client) ModelChartForm {
|
||||
dates := getDays(7)
|
||||
|
||||
return ModelChartForm{
|
||||
Date: getDates(dates),
|
||||
Value: utils.EachNotNil[string, ModelData](globals.AllModels, func(model string) *ModelData {
|
||||
data := ModelData{
|
||||
Model: model,
|
||||
Data: utils.Each[time.Time, int64](dates, func(date time.Time) int64 {
|
||||
return utils.MustInt(cache, getModelFormat(getFormat(date), model))
|
||||
}),
|
||||
}
|
||||
if utils.Sum(data.Data) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &data
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
func GetRequestData(cache *redis.Client) RequestChartForm {
|
||||
dates := getDays(7)
|
||||
|
||||
return RequestChartForm{
|
||||
Date: getDates(dates),
|
||||
Value: utils.Each[time.Time, int64](dates, func(date time.Time) int64 {
|
||||
return utils.MustInt(cache, getRequestFormat(getFormat(date)))
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
func GetBillingData(cache *redis.Client) BillingChartForm {
|
||||
dates := getDays(30)
|
||||
|
||||
return BillingChartForm{
|
||||
Date: getDates(dates),
|
||||
Value: utils.Each[time.Time, float32](dates, func(date time.Time) float32 {
|
||||
return float32(utils.MustInt(cache, getBillingFormat(getFormat(date)))) / 100.
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
func GetErrorData(cache *redis.Client) ErrorChartForm {
|
||||
dates := getDays(7)
|
||||
|
||||
return ErrorChartForm{
|
||||
Date: getDates(dates),
|
||||
Value: utils.Each[time.Time, int64](dates, func(date time.Time) int64 {
|
||||
return utils.MustInt(cache, getErrorFormat(getFormat(date)))
|
||||
}),
|
||||
}
|
||||
}
|
38
admin/controller.go
Normal file
38
admin/controller.go
Normal file
@ -0,0 +1,38 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"chat/utils"
|
||||
"github.com/gin-gonic/gin"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
func InfoAPI(c *gin.Context) {
|
||||
db := utils.GetDBFromContext(c)
|
||||
cache := utils.GetCacheFromContext(c)
|
||||
|
||||
c.JSON(http.StatusOK, InfoForm{
|
||||
SubscriptionCount: GetSubscriptionUsers(db),
|
||||
BillingToday: GetBillingToday(cache),
|
||||
BillingMonth: GetBillingMonth(cache),
|
||||
})
|
||||
}
|
||||
|
||||
func ModelAnalysisAPI(c *gin.Context) {
|
||||
cache := utils.GetCacheFromContext(c)
|
||||
c.JSON(http.StatusOK, GetModelData(cache))
|
||||
}
|
||||
|
||||
func RequestAnalysisAPI(c *gin.Context) {
|
||||
cache := utils.GetCacheFromContext(c)
|
||||
c.JSON(http.StatusOK, GetRequestData(cache))
|
||||
}
|
||||
|
||||
func BillingAnalysisAPI(c *gin.Context) {
|
||||
cache := utils.GetCacheFromContext(c)
|
||||
c.JSON(http.StatusOK, GetBillingData(cache))
|
||||
}
|
||||
|
||||
func ErrorAnalysisAPI(c *gin.Context) {
|
||||
cache := utils.GetCacheFromContext(c)
|
||||
c.JSON(http.StatusOK, GetErrorData(cache))
|
||||
}
|
46
admin/format.go
Normal file
46
admin/format.go
Normal file
@ -0,0 +1,46 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
func getMonth() string {
|
||||
date := time.Now()
|
||||
return date.Format("2006-01")
|
||||
}
|
||||
|
||||
func getDay() string {
|
||||
date := time.Now()
|
||||
return date.Format("2006-01-02")
|
||||
}
|
||||
|
||||
func getDays(n int) []time.Time {
|
||||
current := time.Now()
|
||||
var days []time.Time
|
||||
for i := n; i > 0; i-- {
|
||||
days = append(days, current.AddDate(0, 0, -i+1))
|
||||
}
|
||||
|
||||
return days
|
||||
}
|
||||
|
||||
func getErrorFormat(t string) string {
|
||||
return fmt.Sprintf("nio:err-analysis-%s", t)
|
||||
}
|
||||
|
||||
func getBillingFormat(t string) string {
|
||||
return fmt.Sprintf("nio:billing-analysis-%s", t)
|
||||
}
|
||||
|
||||
func getMonthBillingFormat(t string) string {
|
||||
return fmt.Sprintf("nio:billing-analysis-%s", t)
|
||||
}
|
||||
|
||||
func getRequestFormat(t string) string {
|
||||
return fmt.Sprintf("nio:request-analysis-%s", t)
|
||||
}
|
||||
|
||||
func getModelFormat(t string, model string) string {
|
||||
return fmt.Sprintf("nio:model-analysis-%s-%s", model, t)
|
||||
}
|
11
admin/router.go
Normal file
11
admin/router.go
Normal file
@ -0,0 +1,11 @@
|
||||
package admin
|
||||
|
||||
import "github.com/gin-gonic/gin"
|
||||
|
||||
func Register(app *gin.Engine) {
|
||||
app.GET("/admin/analytics/info", InfoAPI)
|
||||
app.GET("/admin/analytics/model", ModelAnalysisAPI)
|
||||
app.GET("/admin/analytics/request", RequestAnalysisAPI)
|
||||
app.GET("/admin/analytics/billing", BillingAnalysisAPI)
|
||||
app.GET("/admin/analytics/error", ErrorAnalysisAPI)
|
||||
}
|
37
admin/statistic.go
Normal file
37
admin/statistic.go
Normal file
@ -0,0 +1,37 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"chat/connection"
|
||||
"chat/utils"
|
||||
"github.com/go-redis/redis/v8"
|
||||
"time"
|
||||
)
|
||||
|
||||
func IncrErrorRequest(cache *redis.Client) {
|
||||
utils.IncrOnce(cache, getErrorFormat(getDay()), time.Hour*24*7*2)
|
||||
}
|
||||
|
||||
func IncrBillingRequest(cache *redis.Client, amount int64) {
|
||||
utils.IncrWithExpire(cache, getBillingFormat(getDay()), amount, time.Hour*24*30*2)
|
||||
utils.IncrWithExpire(cache, getMonthBillingFormat(getMonth()), amount, time.Hour*24*30*2)
|
||||
}
|
||||
|
||||
func IncrRequest(cache *redis.Client) {
|
||||
utils.IncrOnce(cache, getRequestFormat(getDay()), time.Hour*24*7*2)
|
||||
}
|
||||
|
||||
func IncrModelRequest(cache *redis.Client, model string, tokens int64) {
|
||||
IncrRequest(cache)
|
||||
utils.IncrWithExpire(cache, getModelFormat(getDay(), model), tokens, time.Hour*24*7*2)
|
||||
}
|
||||
|
||||
func AnalysisRequest(model string, buffer *utils.Buffer, err error) {
|
||||
instance := connection.Cache
|
||||
|
||||
if err != nil && err.Error() != "signal" {
|
||||
IncrErrorRequest(instance)
|
||||
return
|
||||
}
|
||||
|
||||
IncrModelRequest(instance, model, int64(buffer.CountToken()))
|
||||
}
|
32
admin/types.go
Normal file
32
admin/types.go
Normal file
@ -0,0 +1,32 @@
|
||||
package admin
|
||||
|
||||
type InfoForm struct {
|
||||
BillingToday int64 `json:"billing_today"`
|
||||
BillingMonth int64 `json:"billing_month"`
|
||||
SubscriptionCount int64 `json:"subscription_count"`
|
||||
}
|
||||
|
||||
type ModelData struct {
|
||||
Model string `json:"model"`
|
||||
Data []int64 `json:"data"`
|
||||
}
|
||||
|
||||
type ModelChartForm struct {
|
||||
Date []string `json:"date"`
|
||||
Value []ModelData `json:"value"`
|
||||
}
|
||||
|
||||
type RequestChartForm struct {
|
||||
Date []string `json:"date"`
|
||||
Value []int64 `json:"value"`
|
||||
}
|
||||
|
||||
type BillingChartForm struct {
|
||||
Date []string `json:"date"`
|
||||
Value []float32 `json:"value"`
|
||||
}
|
||||
|
||||
type ErrorChartForm struct {
|
||||
Date []string `json:"date"`
|
||||
Value []int64 `json:"value"`
|
||||
}
|
@ -22,7 +22,12 @@ export async function getModelChart(): Promise<ModelChartResponse> {
|
||||
return { date: [], value: [] };
|
||||
}
|
||||
|
||||
return response.data as ModelChartResponse;
|
||||
const data = response.data as ModelChartResponse;
|
||||
|
||||
return {
|
||||
date: data.date,
|
||||
value: data.value || [],
|
||||
}
|
||||
}
|
||||
|
||||
export async function getRequestChart(): Promise<RequestChartResponse> {
|
||||
|
@ -8,7 +8,7 @@ import {
|
||||
} from "@/utils/env.ts";
|
||||
import { getMemory } from "@/utils/memory.ts";
|
||||
|
||||
export const version = "3.6.13rc1";
|
||||
export const version = "3.6.14";
|
||||
export const dev: boolean = getDev();
|
||||
export const deploy: boolean = true;
|
||||
export let rest_api: string = getRestApi(deploy);
|
||||
|
@ -12,7 +12,7 @@ function Admin() {
|
||||
|
||||
useEffect(() => {
|
||||
if (init && !admin) router.navigate("/");
|
||||
}, []);
|
||||
}, [init]);
|
||||
|
||||
return (
|
||||
<div className={`admin-page`}>
|
||||
|
@ -39,6 +39,25 @@ func RequireAuth(c *gin.Context) *User {
|
||||
return user
|
||||
}
|
||||
|
||||
func RequireAdmin(c *gin.Context) *User {
|
||||
user := RequireAuth(c)
|
||||
if user == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
db := utils.GetDBFromContext(c)
|
||||
if !user.IsAdmin(db) {
|
||||
c.JSON(200, gin.H{
|
||||
"status": false,
|
||||
"error": "admin required",
|
||||
})
|
||||
c.Abort()
|
||||
return nil
|
||||
}
|
||||
|
||||
return user
|
||||
}
|
||||
|
||||
func RequireSubscribe(c *gin.Context) *User {
|
||||
user := RequireAuth(c)
|
||||
if user == nil {
|
||||
@ -127,6 +146,7 @@ func SubscribeAPI(c *gin.Context) {
|
||||
}
|
||||
|
||||
db := utils.GetDBFromContext(c)
|
||||
cache := utils.GetCacheFromContext(c)
|
||||
var form SubscribeForm
|
||||
if err := c.ShouldBindJSON(&form); err != nil {
|
||||
c.JSON(200, gin.H{
|
||||
@ -144,7 +164,7 @@ func SubscribeAPI(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if BuySubscription(db, user, form.Month) {
|
||||
if BuySubscription(db, cache, user, form.Month) {
|
||||
c.JSON(200, gin.H{
|
||||
"status": true,
|
||||
"error": "success",
|
||||
@ -164,6 +184,7 @@ func BuyAPI(c *gin.Context) {
|
||||
}
|
||||
|
||||
db := utils.GetDBFromContext(c)
|
||||
cache := utils.GetCacheFromContext(c)
|
||||
var form BuyForm
|
||||
if err := c.ShouldBindJSON(&form); err != nil {
|
||||
c.JSON(200, gin.H{
|
||||
@ -181,7 +202,7 @@ func BuyAPI(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if BuyQuota(db, user, form.Quota) {
|
||||
if BuyQuota(db, cache, user, form.Quota) {
|
||||
c.JSON(200, gin.H{
|
||||
"status": true,
|
||||
"error": "success",
|
||||
|
@ -1,9 +1,11 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"chat/admin"
|
||||
"chat/utils"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"github.com/go-redis/redis/v8"
|
||||
"github.com/spf13/viper"
|
||||
)
|
||||
|
||||
@ -64,9 +66,17 @@ func Pay(username string, amount float32) bool {
|
||||
return resp.Type
|
||||
}
|
||||
|
||||
func BuyQuota(db *sql.DB, user *User, quota int) bool {
|
||||
func (u *User) Pay(cache *redis.Client, amount float32) bool {
|
||||
state := Pay(u.Username, amount)
|
||||
if state {
|
||||
admin.IncrBillingRequest(cache, int64(amount*100))
|
||||
}
|
||||
return state
|
||||
}
|
||||
|
||||
func BuyQuota(db *sql.DB, cache *redis.Client, user *User, quota int) bool {
|
||||
money := float32(quota) * 0.1
|
||||
if Pay(user.Username, money) {
|
||||
if user.Pay(cache, money) {
|
||||
user.IncreaseQuota(db, float32(quota))
|
||||
return true
|
||||
}
|
||||
|
@ -19,12 +19,12 @@ func CountSubscriptionPrize(month int) float32 {
|
||||
return base
|
||||
}
|
||||
|
||||
func BuySubscription(db *sql.DB, user *User, month int) bool {
|
||||
func BuySubscription(db *sql.DB, cache *redis.Client, user *User, month int) bool {
|
||||
if month < 1 || month > 999 {
|
||||
return false
|
||||
}
|
||||
money := CountSubscriptionPrize(month)
|
||||
if Pay(user.Username, money) {
|
||||
if user.Pay(cache, money) {
|
||||
user.AddSubscription(db, month)
|
||||
return true
|
||||
}
|
||||
|
16
auth/user.go
16
auth/user.go
@ -21,6 +21,7 @@ type User struct {
|
||||
BindID int64 `json:"bind_id"`
|
||||
Password string `json:"password"`
|
||||
Token string `json:"token"`
|
||||
Admin bool `json:"is_admin"`
|
||||
Subscription *time.Time `json:"subscription"`
|
||||
}
|
||||
|
||||
@ -82,6 +83,20 @@ func (u *User) GenerateToken() (string, error) {
|
||||
return token, nil
|
||||
}
|
||||
|
||||
func (u *User) IsAdmin(db *sql.DB) bool {
|
||||
if u.Admin {
|
||||
return true
|
||||
}
|
||||
|
||||
var admin sql.NullBool
|
||||
if err := db.QueryRow("SELECT is_admin FROM auth WHERE username = ?", u.Username).Scan(&admin); err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
u.Admin = admin.Valid && admin.Bool
|
||||
return u.Admin
|
||||
}
|
||||
|
||||
func (u *User) GetID(db *sql.DB) int64 {
|
||||
if u.ID > 0 {
|
||||
return u.ID
|
||||
@ -345,6 +360,7 @@ func StateAPI(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"status": len(username) != 0,
|
||||
"user": username,
|
||||
"admin": utils.GetAdminFromContext(c),
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -56,7 +56,8 @@ func CreateUserTable(db *sql.DB) {
|
||||
bind_id INT UNIQUE,
|
||||
username VARCHAR(24) UNIQUE,
|
||||
token VARCHAR(255) NOT NULL,
|
||||
password VARCHAR(64) NOT NULL
|
||||
password VARCHAR(64) NOT NULL,
|
||||
is_admin BOOLEAN DEFAULT FALSE
|
||||
);
|
||||
`)
|
||||
if err != nil {
|
||||
|
2
main.go
2
main.go
@ -2,6 +2,7 @@ package main
|
||||
|
||||
import (
|
||||
"chat/addition"
|
||||
"chat/admin"
|
||||
"chat/auth"
|
||||
"chat/cli"
|
||||
"chat/manager"
|
||||
@ -27,6 +28,7 @@ func main() {
|
||||
|
||||
{
|
||||
auth.Register(app)
|
||||
admin.Register(app)
|
||||
manager.Register(app)
|
||||
addition.Register(app)
|
||||
conversation.Register(app)
|
||||
|
@ -3,6 +3,7 @@ package manager
|
||||
import (
|
||||
"chat/adapter"
|
||||
"chat/addition/web"
|
||||
"chat/admin"
|
||||
"chat/auth"
|
||||
"chat/globals"
|
||||
"chat/manager/conversation"
|
||||
@ -103,6 +104,7 @@ func ChatHandler(conn *Connection, user *auth.User, instance *conversation.Conve
|
||||
})
|
||||
})
|
||||
|
||||
admin.AnalysisRequest(model, buffer, err)
|
||||
if err != nil && err.Error() != "signal" {
|
||||
globals.Warn(fmt.Sprintf("caught error from chat handler: %s (instance: %s, client: %s)", err, model, conn.GetCtx().ClientIP()))
|
||||
|
||||
|
@ -3,6 +3,7 @@ package manager
|
||||
import (
|
||||
"chat/adapter"
|
||||
"chat/addition/web"
|
||||
"chat/admin"
|
||||
"chat/auth"
|
||||
"chat/globals"
|
||||
"chat/utils"
|
||||
@ -38,14 +39,17 @@ func NativeChatHandler(c *gin.Context, user *auth.User, model string, message []
|
||||
}
|
||||
|
||||
buffer := utils.NewBuffer(model, segment)
|
||||
if err := adapter.NewChatRequest(&adapter.ChatProps{
|
||||
err := adapter.NewChatRequest(&adapter.ChatProps{
|
||||
Model: model,
|
||||
Plan: plan,
|
||||
Message: segment,
|
||||
}, func(resp string) error {
|
||||
buffer.Write(resp)
|
||||
return nil
|
||||
}); err != nil {
|
||||
})
|
||||
|
||||
admin.AnalysisRequest(model, buffer, err)
|
||||
if err != nil {
|
||||
auth.RevertSubscriptionUsage(cache, user, model, plan)
|
||||
CollectQuota(c, user, buffer, plan)
|
||||
return keyword, err.Error(), GetErrorQuota(model)
|
||||
|
@ -2,6 +2,7 @@ package manager
|
||||
|
||||
import (
|
||||
"chat/adapter"
|
||||
"chat/admin"
|
||||
"chat/auth"
|
||||
"chat/globals"
|
||||
"chat/utils"
|
||||
@ -125,6 +126,8 @@ func sendTranshipmentResponse(c *gin.Context, form TranshipmentForm, id string,
|
||||
buffer.Write(data)
|
||||
return nil
|
||||
})
|
||||
|
||||
admin.AnalysisRequest(form.Model, buffer, err)
|
||||
if err != nil {
|
||||
globals.Warn(fmt.Sprintf("error from chat request api: %s", err.Error()))
|
||||
}
|
||||
@ -180,7 +183,7 @@ func sendStreamTranshipmentResponse(c *gin.Context, form TranshipmentForm, id st
|
||||
|
||||
go func() {
|
||||
buffer := utils.NewBuffer(form.Model, form.Messages)
|
||||
if err := adapter.NewChatRequest(&adapter.ChatProps{
|
||||
err := adapter.NewChatRequest(&adapter.ChatProps{
|
||||
Model: form.Model,
|
||||
Message: form.Messages,
|
||||
Plan: plan,
|
||||
@ -188,7 +191,10 @@ func sendStreamTranshipmentResponse(c *gin.Context, form TranshipmentForm, id st
|
||||
}, func(data string) error {
|
||||
channel <- getStreamTranshipmentForm(id, created, form, buffer.Write(data), buffer, false)
|
||||
return nil
|
||||
}); err != nil {
|
||||
})
|
||||
|
||||
admin.AnalysisRequest(form.Model, buffer, err)
|
||||
if err != nil {
|
||||
channel <- getStreamTranshipmentForm(id, created, form, fmt.Sprintf("Error: %s", err.Error()), buffer, true)
|
||||
CollectQuota(c, user, buffer, plan)
|
||||
close(channel)
|
||||
|
@ -8,18 +8,21 @@ import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
func ProcessToken(c *gin.Context, token string) bool {
|
||||
func ProcessToken(c *gin.Context, token string) *auth.User {
|
||||
if user := auth.ParseToken(c, token); user != nil {
|
||||
c.Set("auth", true)
|
||||
c.Set("user", user.Username)
|
||||
c.Set("agent", "token")
|
||||
c.Next()
|
||||
return true
|
||||
}
|
||||
return false
|
||||
return user
|
||||
}
|
||||
|
||||
func ProcessKey(c *gin.Context, key string) bool {
|
||||
c.Set("auth", false)
|
||||
c.Set("user", "")
|
||||
c.Set("agent", "")
|
||||
return nil
|
||||
}
|
||||
|
||||
func ProcessKey(c *gin.Context, key string) *auth.User {
|
||||
addr := c.ClientIP()
|
||||
cache := utils.GetCacheFromContext(c)
|
||||
|
||||
@ -28,15 +31,14 @@ func ProcessKey(c *gin.Context, key string) bool {
|
||||
"code": 403,
|
||||
"message": "ip in black list",
|
||||
})
|
||||
return false
|
||||
return nil
|
||||
}
|
||||
|
||||
if user := auth.ParseApiKey(c, key); user != nil {
|
||||
c.Set("auth", true)
|
||||
c.Set("user", user.Username)
|
||||
c.Set("agent", "api")
|
||||
c.Next()
|
||||
return true
|
||||
return user
|
||||
}
|
||||
|
||||
utils.IncrIP(cache, addr)
|
||||
@ -44,31 +46,50 @@ func ProcessKey(c *gin.Context, key string) bool {
|
||||
"code": 401,
|
||||
"message": "Access denied. Please provide correct api key.",
|
||||
})
|
||||
return false
|
||||
return nil
|
||||
}
|
||||
|
||||
func AuthMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
func ProcessAuthorization(c *gin.Context) *auth.User {
|
||||
k := strings.TrimSpace(c.GetHeader("Authorization"))
|
||||
if k != "" {
|
||||
if strings.HasPrefix(k, "Bearer ") {
|
||||
k = strings.TrimPrefix(k, "Bearer ")
|
||||
}
|
||||
|
||||
if strings.HasPrefix(k, "sk-") { // api agent
|
||||
if ProcessKey(c, k) {
|
||||
return
|
||||
}
|
||||
} else { // token agent
|
||||
if ProcessToken(c, k) {
|
||||
return
|
||||
}
|
||||
if strings.HasPrefix(k, "sk-") {
|
||||
// api agent
|
||||
return ProcessKey(c, k)
|
||||
} else {
|
||||
// token agent
|
||||
return ProcessToken(c, k)
|
||||
}
|
||||
}
|
||||
|
||||
c.Set("auth", false)
|
||||
c.Set("user", "")
|
||||
c.Set("agent", "")
|
||||
return nil
|
||||
}
|
||||
|
||||
func AuthMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
path := c.Request.URL.Path
|
||||
instance := ProcessAuthorization(c)
|
||||
|
||||
db := utils.GetDBFromContext(c)
|
||||
|
||||
admin := instance != nil && instance.IsAdmin(db)
|
||||
c.Set("admin", admin)
|
||||
if strings.HasPrefix(path, "/admin") {
|
||||
if !admin {
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{
|
||||
"code": 401,
|
||||
"message": "Access denied.",
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
@ -5,6 +5,14 @@ import (
|
||||
"fmt"
|
||||
)
|
||||
|
||||
func Sum[T int | int64 | float32 | float64](arr []T) T {
|
||||
var res T
|
||||
for _, v := range arr {
|
||||
res += v
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
||||
func Contains[T comparable](value T, slice []T) bool {
|
||||
for _, item := range slice {
|
||||
if item == value {
|
||||
@ -123,3 +131,21 @@ func InsertChannel[T any](ch chan T, value T, index int) {
|
||||
ch <- v
|
||||
}
|
||||
}
|
||||
|
||||
func Each[T any, U any](arr []T, f func(T) U) []U {
|
||||
var res []U
|
||||
for _, v := range arr {
|
||||
res = append(res, f(v))
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
||||
func EachNotNil[T any, U any](arr []T, f func(T) *U) []U {
|
||||
var res []U
|
||||
for _, v := range arr {
|
||||
if val := f(v); val != nil {
|
||||
res = append(res, *val)
|
||||
}
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
@ -65,6 +65,17 @@ func IncrIP(cache *redis.Client, ip string) int64 {
|
||||
return val
|
||||
}
|
||||
|
||||
func IncrWithExpire(cache *redis.Client, key string, delta int64, expiration time.Duration) {
|
||||
_, err := Incr(cache, key, delta)
|
||||
if err != nil && err == redis.Nil {
|
||||
cache.Set(context.Background(), key, delta, expiration)
|
||||
}
|
||||
}
|
||||
|
||||
func IncrOnce(cache *redis.Client, key string, expiration time.Duration) {
|
||||
IncrWithExpire(cache, key, 1, expiration)
|
||||
}
|
||||
|
||||
func IsInBlackList(cache *redis.Client, ip string) bool {
|
||||
val, err := GetInt(cache, fmt.Sprintf(":ip-rate:%s", ip))
|
||||
return err == nil && val > 50
|
||||
|
@ -18,6 +18,10 @@ func GetUserFromContext(c *gin.Context) string {
|
||||
return c.MustGet("user").(string)
|
||||
}
|
||||
|
||||
func GetAdminFromContext(c *gin.Context) bool {
|
||||
return c.MustGet("admin").(bool)
|
||||
}
|
||||
|
||||
func GetAgentFromContext(c *gin.Context) string {
|
||||
return c.MustGet("agent").(string)
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user