mirror of
https://github.com/coaidev/coai.git
synced 2025-05-20 21:40:15 +09:00
359 lines
8.8 KiB
Go
359 lines
8.8 KiB
Go
package auth
|
|
|
|
import (
|
|
"chat/globals"
|
|
"chat/utils"
|
|
"database/sql"
|
|
"errors"
|
|
"fmt"
|
|
"github.com/dgrijalva/jwt-go"
|
|
"github.com/gin-gonic/gin"
|
|
"github.com/go-redis/redis/v8"
|
|
"github.com/spf13/viper"
|
|
"math"
|
|
"net/http"
|
|
"time"
|
|
)
|
|
|
|
// User is an account row from the auth table. Fields are filled lazily:
// GetUser sets only Username, ParseApiKey sets ID/Username/Password, and
// GetID / GetSubscription populate ID and Subscription on first use.
type User struct {
	// ID is the auth.id primary key; values <= 0 mean "not loaded yet"
	// (see GetID).
	ID       int64  `json:"id"`
	Username string `json:"username"`
	// BindID mirrors auth.bind_id, which Login fills from the validated
	// external user's ID at registration.
	BindID int64 `json:"bind_id"`
	// Password is the per-user random secret generated at registration and
	// checked by Validate. NOTE(review): it carries a json tag, so it is
	// emitted if a User is ever marshalled — confirm that never reaches clients.
	Password string `json:"password"`
	Token    string `json:"token"`
	// Subscription caches the expiry loaded by GetSubscription; nil until
	// loaded.
	Subscription *time.Time `json:"subscription"`
}
|
|
|
|
// LoginForm is the request payload bound by LoginAPI. Token is the external
// access token that Login passes to Validate; binding fails when it is absent.
type LoginForm struct {
	Token string `form:"token" binding:"required"`
}
|
|
|
|
func GetUser(c *gin.Context) *User {
|
|
if c.GetBool("auth") {
|
|
return &User{
|
|
Username: c.GetString("user"),
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func GetId(db *sql.DB, user *User) int64 {
|
|
if user == nil {
|
|
return -1
|
|
}
|
|
return user.GetID(db)
|
|
}
|
|
|
|
func (u *User) Validate(c *gin.Context) bool {
|
|
if u.Username == "" || u.Password == "" {
|
|
return false
|
|
}
|
|
cache := utils.GetCacheFromContext(c)
|
|
|
|
if password, err := cache.Get(c, fmt.Sprintf("nio:user:%s", u.Username)).Result(); err == nil && len(password) > 0 {
|
|
return u.Password == password
|
|
}
|
|
|
|
db := utils.GetDBFromContext(c)
|
|
var count int
|
|
if err := db.QueryRow("SELECT COUNT(*) FROM auth WHERE username = ? AND password = ?", u.Username, u.Password).Scan(&count); err != nil || count == 0 {
|
|
return false
|
|
}
|
|
|
|
cache.Set(c, fmt.Sprintf("nio:user:%s", u.Username), u.Password, 30*time.Minute)
|
|
return true
|
|
}
|
|
|
|
func (u *User) GenerateToken() (string, error) {
|
|
instance := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
|
|
"username": u.Username,
|
|
"password": u.Password,
|
|
"exp": time.Now().Add(time.Hour * 24 * 30).Unix(),
|
|
})
|
|
token, err := instance.SignedString([]byte(viper.GetString("secret")))
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
return token, nil
|
|
}
|
|
|
|
func (u *User) GetID(db *sql.DB) int64 {
|
|
if u.ID > 0 {
|
|
return u.ID
|
|
}
|
|
if err := db.QueryRow("SELECT id FROM auth WHERE username = ?", u.Username).Scan(&u.ID); err != nil {
|
|
return 0
|
|
}
|
|
return u.ID
|
|
}
|
|
|
|
func (u *User) GetQuota(db *sql.DB) float32 {
|
|
var quota float32
|
|
if err := db.QueryRow("SELECT quota FROM quota WHERE user_id = ?", u.GetID(db)).Scan("a); err != nil {
|
|
return 0.
|
|
}
|
|
return quota
|
|
}
|
|
|
|
func (u *User) GetUsedQuota(db *sql.DB) float32 {
|
|
var quota float32
|
|
if err := db.QueryRow("SELECT used FROM quota WHERE user_id = ?", u.GetID(db)).Scan("a); err != nil {
|
|
return 0.
|
|
}
|
|
return quota
|
|
}
|
|
|
|
func (u *User) SetQuota(db *sql.DB, quota float32) bool {
|
|
_, err := db.Exec(`
|
|
INSERT INTO quota (user_id, quota, used) VALUES (?, ?, ?) ON DUPLICATE KEY UPDATE quota = ?
|
|
`, u.GetID(db), quota, 0., quota)
|
|
return err == nil
|
|
}
|
|
|
|
func (u *User) SetUsedQuota(db *sql.DB, used float32) bool {
|
|
_, err := db.Exec(`
|
|
INSERT INTO quota (user_id, quota, used) VALUES (?, ?, ?) ON DUPLICATE KEY UPDATE used = ?
|
|
`, u.GetID(db), 0., used, used)
|
|
return err == nil
|
|
}
|
|
|
|
func (u *User) IncreaseQuota(db *sql.DB, quota float32) bool {
|
|
_, err := db.Exec(`
|
|
INSERT INTO quota (user_id, quota, used) VALUES (?, ?, ?) ON DUPLICATE KEY UPDATE quota = quota + ?
|
|
`, u.GetID(db), quota, 0., quota)
|
|
return err == nil
|
|
}
|
|
|
|
func (u *User) IncreaseUsedQuota(db *sql.DB, used float32) bool {
|
|
_, err := db.Exec(`
|
|
INSERT INTO quota (user_id, quota, used) VALUES (?, ?, ?) ON DUPLICATE KEY UPDATE used = used + ?
|
|
`, u.GetID(db), 0., used, used)
|
|
return err == nil
|
|
}
|
|
|
|
func (u *User) DecreaseQuota(db *sql.DB, quota float32) bool {
|
|
_, err := db.Exec(`
|
|
INSERT INTO quota (user_id, quota, used) VALUES (?, ?, ?) ON DUPLICATE KEY UPDATE quota = quota - ?
|
|
`, u.GetID(db), quota, 0., quota)
|
|
return err == nil
|
|
}
|
|
|
|
func (u *User) UseQuota(db *sql.DB, quota float32) bool {
|
|
if quota == 0 {
|
|
return true
|
|
}
|
|
if !u.DecreaseQuota(db, quota) {
|
|
return false
|
|
}
|
|
return u.IncreaseUsedQuota(db, quota)
|
|
}
|
|
|
|
func (u *User) GetSubscription(db *sql.DB) time.Time {
|
|
if u.Subscription != nil && u.Subscription.Unix() > 0 {
|
|
return *u.Subscription
|
|
}
|
|
|
|
var expiredAt []uint8
|
|
if err := db.QueryRow("SELECT expired_at FROM subscription WHERE user_id = ?", u.GetID(db)).Scan(&expiredAt); err != nil {
|
|
return time.Unix(0, 0)
|
|
}
|
|
|
|
u.Subscription = utils.ConvertTime(expiredAt)
|
|
return *u.Subscription
|
|
}
|
|
|
|
func (u *User) IsSubscribe(db *sql.DB) bool {
|
|
return u.GetSubscription(db).Unix() > time.Now().Unix()
|
|
}
|
|
|
|
func (u *User) IsEnterprise(db *sql.DB) bool {
|
|
if !u.IsSubscribe(db) {
|
|
return false
|
|
}
|
|
|
|
var enterprise sql.NullBool
|
|
if err := db.QueryRow("SELECT enterprise FROM subscription WHERE user_id = ?", u.GetID(db)).Scan(&enterprise); err != nil {
|
|
return false
|
|
}
|
|
|
|
return enterprise.Valid && enterprise.Bool
|
|
}
|
|
|
|
func (u *User) GetSubscriptionExpiredDay(db *sql.DB) int {
|
|
stamp := u.GetSubscription(db).Sub(time.Now())
|
|
return int(math.Round(stamp.Hours() / 24))
|
|
}
|
|
|
|
// Usage holds per-model subscription usage counters returned by
// GetSubscriptionUsage.
type Usage struct {
	// GPT4 is read from the cache key built by
	// globals.GetSubscriptionLimitFormat(globals.GPT4, userID).
	GPT4 int64 `json:"gpt4"`
}
|
|
|
|
func (u *User) GetSubscriptionUsage(db *sql.DB, cache *redis.Client) Usage {
|
|
return Usage{
|
|
GPT4: utils.MustInt(cache, globals.GetSubscriptionLimitFormat(globals.GPT4, u.GetID(db))),
|
|
}
|
|
}
|
|
|
|
func (u *User) AddSubscription(db *sql.DB, month int) bool {
|
|
current := u.GetSubscription(db)
|
|
if current.Unix() < time.Now().Unix() {
|
|
current = time.Now()
|
|
}
|
|
expiredAt := current.AddDate(0, month, 0)
|
|
_, err := db.Exec(`
|
|
INSERT INTO subscription (user_id, expired_at, total_month) VALUES (?, ?, ?)
|
|
ON DUPLICATE KEY UPDATE expired_at = ?, total_month = total_month + ?
|
|
`, u.GetID(db), utils.ConvertSqlTime(expiredAt), month, utils.ConvertSqlTime(expiredAt), month)
|
|
return err == nil
|
|
}
|
|
|
|
func (u *User) CreateApiKey(db *sql.DB) string {
|
|
salt := utils.Sha2Encrypt(fmt.Sprintf("%s-%s", u.Username, utils.GenerateChar(utils.GetRandomInt(720, 1024))))
|
|
key := fmt.Sprintf("sk-%s", salt[:64]) // 64 bytes
|
|
if _, err := db.Exec("INSERT INTO apikey (user_id, api_key) VALUES (?, ?)", u.GetID(db), key); err != nil {
|
|
return ""
|
|
}
|
|
return key
|
|
}
|
|
|
|
func (u *User) GetApiKey(db *sql.DB) string {
|
|
var key string
|
|
if err := db.QueryRow("SELECT api_key FROM apikey WHERE user_id = ?", u.GetID(db)).Scan(&key); err != nil {
|
|
return u.CreateApiKey(db)
|
|
}
|
|
return key
|
|
}
|
|
|
|
// IsUserExist reports whether the auth table has a row for username. Any query
// error is treated as "does not exist".
func IsUserExist(db *sql.DB, username string) bool {
	var n int
	err := db.QueryRow("SELECT COUNT(*) FROM auth WHERE username = ?", username).Scan(&n)
	return err == nil && n > 0
}
|
|
|
|
func ParseToken(c *gin.Context, token string) *User {
|
|
instance, err := jwt.Parse(token, func(token *jwt.Token) (interface{}, error) {
|
|
return []byte(viper.GetString("secret")), nil
|
|
})
|
|
if err != nil {
|
|
return nil
|
|
}
|
|
if claims, ok := instance.Claims.(jwt.MapClaims); ok && instance.Valid {
|
|
if int64(claims["exp"].(float64)) < time.Now().Unix() {
|
|
return nil
|
|
}
|
|
user := &User{
|
|
Username: claims["username"].(string),
|
|
Password: claims["password"].(string),
|
|
}
|
|
if !user.Validate(c) {
|
|
return nil
|
|
}
|
|
return user
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func ParseApiKey(c *gin.Context, key string) *User {
|
|
db := utils.GetDBFromContext(c)
|
|
|
|
if len(key) == 0 {
|
|
return nil
|
|
}
|
|
|
|
var user User
|
|
if err := db.QueryRow(`
|
|
SELECT auth.id, auth.username, auth.password FROM auth
|
|
INNER JOIN apikey ON auth.id = apikey.user_id
|
|
WHERE apikey.api_key = ?
|
|
`, key).Scan(&user.ID, &user.Username, &user.Password); err != nil {
|
|
return nil
|
|
}
|
|
|
|
return &user
|
|
}
|
|
|
|
func Login(c *gin.Context, token string) (string, error) {
|
|
// DeepTrain Token Validation
|
|
user := Validate(token)
|
|
if user == nil {
|
|
return "", errors.New("cannot validate access token")
|
|
}
|
|
|
|
db := utils.GetDBFromContext(c)
|
|
if !IsUserExist(db, user.Username) {
|
|
// register
|
|
password := utils.GenerateChar(64)
|
|
_ = db.QueryRow("INSERT INTO auth (bind_id, username, token, password) VALUES (?, ?, ?, ?)",
|
|
user.ID, user.Username, token, password)
|
|
u := &User{
|
|
Username: user.Username,
|
|
Password: password,
|
|
}
|
|
return u.GenerateToken()
|
|
}
|
|
|
|
// login
|
|
_ = db.QueryRow("UPDATE auth SET token = ? WHERE username = ?", token, user.Username)
|
|
var password string
|
|
err := db.QueryRow("SELECT password FROM auth WHERE username = ?", user.Username).Scan(&password)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
u := &User{
|
|
Username: user.Username,
|
|
Password: password,
|
|
}
|
|
return u.GenerateToken()
|
|
}
|
|
|
|
func LoginAPI(c *gin.Context) {
|
|
var form LoginForm
|
|
if err := c.ShouldBind(&form); err != nil {
|
|
c.JSON(http.StatusOK, gin.H{
|
|
"status": false,
|
|
"error": "bad request",
|
|
})
|
|
return
|
|
}
|
|
|
|
token, err := Login(c, form.Token)
|
|
if err != nil {
|
|
c.JSON(http.StatusOK, gin.H{
|
|
"status": false,
|
|
"error": err.Error(),
|
|
})
|
|
return
|
|
}
|
|
|
|
c.JSON(http.StatusOK, gin.H{
|
|
"status": true,
|
|
"token": token,
|
|
})
|
|
}
|
|
|
|
func StateAPI(c *gin.Context) {
|
|
username := utils.GetUserFromContext(c)
|
|
c.JSON(http.StatusOK, gin.H{
|
|
"status": len(username) != 0,
|
|
"user": username,
|
|
})
|
|
}
|
|
|
|
func KeyAPI(c *gin.Context) {
|
|
user := GetUser(c)
|
|
if user == nil {
|
|
c.JSON(http.StatusOK, gin.H{
|
|
"status": false,
|
|
"error": "user not found",
|
|
})
|
|
return
|
|
}
|
|
|
|
c.JSON(http.StatusOK, gin.H{
|
|
"status": true,
|
|
"key": user.GetApiKey(utils.GetDBFromContext(c)),
|
|
})
|
|
}
|