mirror of
https://github.com/coaidev/coai.git
synced 2025-05-19 04:50:14 +09:00
update: all in one feature
This commit is contained in:
parent
aa627fb61d
commit
abff5b9821
160
auth/auth.go
160
auth/auth.go
@ -1,6 +1,7 @@
|
|||||||
package auth
|
package auth
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"chat/channel"
|
||||||
"chat/globals"
|
"chat/globals"
|
||||||
"chat/utils"
|
"chat/utils"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
@ -8,7 +9,9 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"github.com/dgrijalva/jwt-go"
|
"github.com/dgrijalva/jwt-go"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/go-redis/redis/v8"
|
||||||
"github.com/spf13/viper"
|
"github.com/spf13/viper"
|
||||||
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -54,9 +57,129 @@ func ParseApiKey(c *gin.Context, key string) *User {
|
|||||||
return &user
|
return &user
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func getCode(c *gin.Context, cache *redis.Client, email string) string {
|
||||||
|
code, err := cache.Get(c, fmt.Sprintf("nio:otp:%s", email)).Result()
|
||||||
|
if err != nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return code
|
||||||
|
}
|
||||||
|
|
||||||
|
func checkCode(c *gin.Context, cache *redis.Client, email, code string) bool {
|
||||||
|
storage := getCode(c, cache, email)
|
||||||
|
if len(storage) == 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if storage != code {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
cache.Del(c, fmt.Sprintf("nio:top:%s", email))
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func setCode(c *gin.Context, cache *redis.Client, email, code string) {
|
||||||
|
cache.Set(c, fmt.Sprintf("nio:otp:%s", email), code, 5*time.Minute)
|
||||||
|
}
|
||||||
|
|
||||||
|
func generateCode(c *gin.Context, cache *redis.Client, email string) string {
|
||||||
|
code := utils.GenerateCode(6)
|
||||||
|
setCode(c, cache, email, code)
|
||||||
|
return code
|
||||||
|
}
|
||||||
|
|
||||||
|
func Verify(c *gin.Context, email string) error {
|
||||||
|
cache := utils.GetCacheFromContext(c)
|
||||||
|
code := generateCode(c, cache, email)
|
||||||
|
|
||||||
|
provider := channel.SystemInstance.GetMail()
|
||||||
|
return provider.SendMail(
|
||||||
|
email,
|
||||||
|
"Chat Nio | OTP Verification",
|
||||||
|
fmt.Sprintf("Your OTP code is: %s", code),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
func SignUp(c *gin.Context, form RegisterForm) (string, error) {
|
||||||
|
db := utils.GetDBFromContext(c)
|
||||||
|
cache := utils.GetCacheFromContext(c)
|
||||||
|
|
||||||
|
username := strings.TrimSpace(form.Username)
|
||||||
|
password := strings.TrimSpace(form.Password)
|
||||||
|
email := strings.TrimSpace(form.Email)
|
||||||
|
code := strings.TrimSpace(form.Code)
|
||||||
|
|
||||||
|
if !utils.All(
|
||||||
|
validateUsername(username),
|
||||||
|
validatePassword(password),
|
||||||
|
validateEmail(email),
|
||||||
|
validateCode(code),
|
||||||
|
) {
|
||||||
|
return "", errors.New("invalid username/password/email format")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !IsUserExist(db, username) {
|
||||||
|
return "", fmt.Errorf("username is already taken, please try another one username (your current username: %s)", username)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !IsEmailExist(db, email) {
|
||||||
|
return "", fmt.Errorf("email is already taken, please try another one email (your current email: %s)", email)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !checkCode(c, cache, email, code) {
|
||||||
|
return "", errors.New("invalid email verification code")
|
||||||
|
}
|
||||||
|
|
||||||
|
hash := utils.Sha2Encrypt(password)
|
||||||
|
|
||||||
|
user := &User{
|
||||||
|
Username: username,
|
||||||
|
Password: hash,
|
||||||
|
Email: email,
|
||||||
|
BindID: getMaxBindId(db) + 1,
|
||||||
|
Token: utils.Sha2Encrypt(email + username),
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := db.Exec(`
|
||||||
|
INSERT INTO auth (username, password, email, bind_id, token)
|
||||||
|
VALUES (?, ?, ?, ?, ?)
|
||||||
|
`, user.Username, user.Password, user.Email, user.BindID, user.Token); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
return user.GenerateToken()
|
||||||
|
}
|
||||||
|
|
||||||
|
func Login(c *gin.Context, form LoginForm) (string, error) {
|
||||||
|
db := utils.GetDBFromContext(c)
|
||||||
|
username := strings.TrimSpace(form.Username)
|
||||||
|
password := strings.TrimSpace(form.Password)
|
||||||
|
|
||||||
|
if !utils.All(
|
||||||
|
validateUsernameOrEmail(username),
|
||||||
|
validatePassword(password),
|
||||||
|
) {
|
||||||
|
return "", errors.New("invalid username or password format")
|
||||||
|
}
|
||||||
|
|
||||||
|
hash := utils.Sha2Encrypt(password)
|
||||||
|
|
||||||
|
// get user from db by username (or email) and password
|
||||||
|
var user User
|
||||||
|
if err := db.QueryRow(`
|
||||||
|
SELECT auth.id, auth.username, auth.password FROM auth
|
||||||
|
WHERE (auth.username = ? OR auth.email = ?) AND auth.password = ?
|
||||||
|
`, username, hash).Scan(&user.ID, &user.Username, &user.Password); err != nil {
|
||||||
|
return "", errors.New("invalid username or password")
|
||||||
|
}
|
||||||
|
|
||||||
|
return user.GenerateToken()
|
||||||
|
}
|
||||||
|
|
||||||
func DeepLogin(c *gin.Context, token string) (string, error) {
|
func DeepLogin(c *gin.Context, token string) (string, error) {
|
||||||
if !useDeeptrain() {
|
if !useDeeptrain() {
|
||||||
return "", errors.New("deeptrain feature is disabled")
|
return "", errors.New("deeptrain mode is disabled")
|
||||||
}
|
}
|
||||||
|
|
||||||
user := Validate(token)
|
user := Validate(token)
|
||||||
@ -91,6 +214,41 @@ func DeepLogin(c *gin.Context, token string) (string, error) {
|
|||||||
return u.GenerateToken()
|
return u.GenerateToken()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func Reset(c *gin.Context, form ResetForm) error {
|
||||||
|
db := utils.GetDBFromContext(c)
|
||||||
|
cache := utils.GetCacheFromContext(c)
|
||||||
|
|
||||||
|
email := strings.TrimSpace(form.Email)
|
||||||
|
code := strings.TrimSpace(form.Code)
|
||||||
|
password := strings.TrimSpace(form.Password)
|
||||||
|
|
||||||
|
if !utils.All(
|
||||||
|
validateEmail(email),
|
||||||
|
validateCode(code),
|
||||||
|
validatePassword(password),
|
||||||
|
) {
|
||||||
|
return errors.New("invalid email/code/password format")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !IsEmailExist(db, email) {
|
||||||
|
return errors.New("email is not registered")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !checkCode(c, cache, email, code) {
|
||||||
|
return errors.New("invalid email verification code")
|
||||||
|
}
|
||||||
|
|
||||||
|
hash := utils.Sha2Encrypt(password)
|
||||||
|
|
||||||
|
if _, err := db.Exec(`
|
||||||
|
UPDATE auth SET password = ? WHERE email = ?
|
||||||
|
`, hash, email); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (u *User) Validate(c *gin.Context) bool {
|
func (u *User) Validate(c *gin.Context) bool {
|
||||||
if u.Username == "" || u.Password == "" {
|
if u.Username == "" || u.Password == "" {
|
||||||
return false
|
return false
|
||||||
|
@ -7,10 +7,32 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type RegisterForm struct {
|
||||||
|
Username string `form:"username" binding:"required"`
|
||||||
|
Password string `form:"password" binding:"required"`
|
||||||
|
Email string `form:"email" binding:"required"`
|
||||||
|
Code string `form:"code" binding:"required"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type VerifyForm struct {
|
||||||
|
Email string `form:"email" binding:"required"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type LoginForm struct {
|
||||||
|
Username string `form:"username" binding:"required"`
|
||||||
|
Password string `form:"password" binding:"required"`
|
||||||
|
}
|
||||||
|
|
||||||
type DeepLoginForm struct {
|
type DeepLoginForm struct {
|
||||||
Token string `form:"token" binding:"required"`
|
Token string `form:"token" binding:"required"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type ResetForm struct {
|
||||||
|
Email string `form:"email" binding:"required"`
|
||||||
|
Code string `form:"code" binding:"required"`
|
||||||
|
Password string `form:"password" binding:"required"`
|
||||||
|
}
|
||||||
|
|
||||||
type BuyForm struct {
|
type BuyForm struct {
|
||||||
Quota int `json:"quota" binding:"required"`
|
Quota int `json:"quota" binding:"required"`
|
||||||
}
|
}
|
||||||
@ -111,8 +133,16 @@ func RequireEnterprise(c *gin.Context) *User {
|
|||||||
return user
|
return user
|
||||||
}
|
}
|
||||||
|
|
||||||
func LoginAPI(c *gin.Context) {
|
func RegisterAPI(c *gin.Context) {
|
||||||
var form DeepLoginForm
|
if useDeeptrain() {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"status": false,
|
||||||
|
"error": "this api is not available for deeptrain mode",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var form RegisterForm
|
||||||
if err := c.ShouldBind(&form); err != nil {
|
if err := c.ShouldBind(&form); err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"status": false,
|
"status": false,
|
||||||
@ -121,7 +151,7 @@ func LoginAPI(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
token, err := DeepLogin(c, form.Token)
|
token, err := SignUp(c, form)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"status": false,
|
"status": false,
|
||||||
@ -136,6 +166,94 @@ func LoginAPI(c *gin.Context) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func LoginAPI(c *gin.Context) {
|
||||||
|
var token string
|
||||||
|
var err error
|
||||||
|
|
||||||
|
if useDeeptrain() {
|
||||||
|
var form DeepLoginForm
|
||||||
|
if err := c.ShouldBind(&form); err != nil {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"status": false,
|
||||||
|
"error": "bad request",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
token, err = DeepLogin(c, form.Token)
|
||||||
|
} else {
|
||||||
|
var form LoginForm
|
||||||
|
if err := c.ShouldBind(&form); err != nil {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"status": false,
|
||||||
|
"error": "bad request",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
token, err = Login(c, form)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"status": false,
|
||||||
|
"error": err.Error(),
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"status": true,
|
||||||
|
"token": token,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func VerifyAPI(c *gin.Context) {
|
||||||
|
var form VerifyForm
|
||||||
|
if err := c.ShouldBind(&form); err != nil {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"status": false,
|
||||||
|
"error": "bad request",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := Verify(c, form.Email); err != nil {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"status": false,
|
||||||
|
"error": err.Error(),
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"status": true,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func ResetAPI(c *gin.Context) {
|
||||||
|
var form ResetForm
|
||||||
|
if err := c.ShouldBind(&form); err != nil {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"status": false,
|
||||||
|
"error": "bad request",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := Reset(c, form); err != nil {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"status": false,
|
||||||
|
"error": err.Error(),
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"status": true,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func StateAPI(c *gin.Context) {
|
func StateAPI(c *gin.Context) {
|
||||||
username := utils.GetUserFromContext(c)
|
username := utils.GetUserFromContext(c)
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
@ -3,6 +3,9 @@ package auth
|
|||||||
import "github.com/gin-gonic/gin"
|
import "github.com/gin-gonic/gin"
|
||||||
|
|
||||||
func Register(app *gin.Engine) {
|
func Register(app *gin.Engine) {
|
||||||
|
app.POST("/verify", VerifyAPI)
|
||||||
|
app.POST("/reset", ResetAPI)
|
||||||
|
app.POST("/register", RegisterAPI)
|
||||||
app.POST("/login", LoginAPI)
|
app.POST("/login", LoginAPI)
|
||||||
app.POST("/state", StateAPI)
|
app.POST("/state", StateAPI)
|
||||||
app.GET("/apikey", KeyAPI)
|
app.GET("/apikey", KeyAPI)
|
||||||
|
@ -9,6 +9,7 @@ import (
|
|||||||
type User struct {
|
type User struct {
|
||||||
ID int64 `json:"id"`
|
ID int64 `json:"id"`
|
||||||
Username string `json:"username"`
|
Username string `json:"username"`
|
||||||
|
Email string `json:"email"`
|
||||||
BindID int64 `json:"bind_id"`
|
BindID int64 `json:"bind_id"`
|
||||||
Password string `json:"password"`
|
Password string `json:"password"`
|
||||||
Token string `json:"token"`
|
Token string `json:"token"`
|
||||||
@ -56,6 +57,20 @@ func (u *User) GetID(db *sql.DB) int64 {
|
|||||||
return u.ID
|
return u.ID
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (u *User) GetEmail(db *sql.DB) string {
|
||||||
|
if len(u.Email) > 0 {
|
||||||
|
return u.Email
|
||||||
|
}
|
||||||
|
|
||||||
|
var email sql.NullString
|
||||||
|
if err := db.QueryRow("SELECT email FROM auth WHERE username = ?", u.Username).Scan(&email); err != nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
u.Email = email.String
|
||||||
|
return u.Email
|
||||||
|
}
|
||||||
|
|
||||||
func IsUserExist(db *sql.DB, username string) bool {
|
func IsUserExist(db *sql.DB, username string) bool {
|
||||||
var count int
|
var count int
|
||||||
if err := db.QueryRow("SELECT COUNT(*) FROM auth WHERE username = ?", username).Scan(&count); err != nil {
|
if err := db.QueryRow("SELECT COUNT(*) FROM auth WHERE username = ?", username).Scan(&count); err != nil {
|
||||||
@ -64,6 +79,22 @@ func IsUserExist(db *sql.DB, username string) bool {
|
|||||||
return count > 0
|
return count > 0
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func IsEmailExist(db *sql.DB, email string) bool {
|
||||||
|
var count int
|
||||||
|
if err := db.QueryRow("SELECT COUNT(*) FROM auth WHERE email = ?", email).Scan(&count); err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return count > 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func getMaxBindId(db *sql.DB) int64 {
|
||||||
|
var max int64
|
||||||
|
if err := db.QueryRow("SELECT MAX(bind_id) FROM auth").Scan(&max); err != nil {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
return max
|
||||||
|
}
|
||||||
|
|
||||||
func GetGroup(db *sql.DB, user *User) string {
|
func GetGroup(db *sql.DB, user *User) string {
|
||||||
if user == nil {
|
if user == nil {
|
||||||
return globals.AnonymousType
|
return globals.AnonymousType
|
||||||
|
36
auth/validators.go
Normal file
36
auth/validators.go
Normal file
@ -0,0 +1,36 @@
|
|||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"regexp"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
func isInRange(content string, min, max int) bool {
|
||||||
|
content = strings.TrimSpace(content)
|
||||||
|
return len(content) >= min && len(content) <= max
|
||||||
|
}
|
||||||
|
|
||||||
|
func validateUsername(username string) bool {
|
||||||
|
return isInRange(username, 2, 24)
|
||||||
|
}
|
||||||
|
|
||||||
|
func validateUsernameOrEmail(username string) bool {
|
||||||
|
return isInRange(username, 1, 255)
|
||||||
|
}
|
||||||
|
|
||||||
|
func validatePassword(password string) bool {
|
||||||
|
return isInRange(password, 6, 36)
|
||||||
|
}
|
||||||
|
|
||||||
|
func validateEmail(email string) bool {
|
||||||
|
if !isInRange(email, 1, 255) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
exp := regexp.MustCompile(`^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$`)
|
||||||
|
return exp.MatchString(email)
|
||||||
|
}
|
||||||
|
|
||||||
|
func validateCode(code string) bool {
|
||||||
|
return isInRange(code, 1, 64)
|
||||||
|
}
|
3
migration/3.8.sql
Normal file
3
migration/3.8.sql
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
ALTER TABLE auth
|
||||||
|
ADD COLUMN email VARCHAR(255) UNIQUE,
|
||||||
|
ADD COLUMN is_banned BOOLEAN DEFAULT FALSE;
|
@ -254,3 +254,21 @@ func GetIndexSafe[T any](arr []T, index int) *T {
|
|||||||
}
|
}
|
||||||
return &arr[index]
|
return &arr[index]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func All(arr ...bool) bool {
|
||||||
|
for _, v := range arr {
|
||||||
|
if !v {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func Any(arr ...bool) bool {
|
||||||
|
for _, v := range arr {
|
||||||
|
if v {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
@ -25,19 +25,17 @@ func NewSmtpPoster(host string, port int, username string, password string, from
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *SmtpPoster) SendMail(to string, subject string, body string) {
|
func (s *SmtpPoster) SendMail(to string, subject string, body string) error {
|
||||||
addr := fmt.Sprintf("%s:%d", s.Host, s.Port)
|
addr := fmt.Sprintf("%s:%d", s.Host, s.Port)
|
||||||
auth := smtp.PlainAuth("", s.From, s.Password, s.Host)
|
auth := smtp.PlainAuth("", s.From, s.Password, s.Host)
|
||||||
err := smtpRequestWithTLS(addr, auth, s.From, []string{to},
|
|
||||||
|
return smtpRequestWithTLS(addr, auth, s.From, []string{to},
|
||||||
[]byte(formatMail(map[string]string{
|
[]byte(formatMail(map[string]string{
|
||||||
"From": fmt.Sprintf("%s <%s>", s.Username, s.From),
|
"From": fmt.Sprintf("%s <%s>", s.Username, s.From),
|
||||||
"To": to,
|
"To": to,
|
||||||
"Subject": subject,
|
"Subject": subject,
|
||||||
"Content-Type": "text/html; charset=utf-8",
|
"Content-Type": "text/html; charset=utf-8",
|
||||||
}, body)))
|
}, body)))
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func dial(addr string) (*smtp.Client, error) {
|
func dial(addr string) (*smtp.Client, error) {
|
||||||
|
Loading…
Reference in New Issue
Block a user