update: all in one feature

This commit is contained in:
Zhang Minghan 2023-12-23 23:02:51 +08:00
parent aa627fb61d
commit abff5b9821
8 changed files with 374 additions and 9 deletions

View File

@ -1,6 +1,7 @@
package auth package auth
import ( import (
"chat/channel"
"chat/globals" "chat/globals"
"chat/utils" "chat/utils"
"database/sql" "database/sql"
@ -8,7 +9,9 @@ import (
"fmt" "fmt"
"github.com/dgrijalva/jwt-go" "github.com/dgrijalva/jwt-go"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/go-redis/redis/v8"
"github.com/spf13/viper" "github.com/spf13/viper"
"strings"
"time" "time"
) )
@ -54,9 +57,129 @@ func ParseApiKey(c *gin.Context, key string) *User {
return &user return &user
} }
func getCode(c *gin.Context, cache *redis.Client, email string) string {
code, err := cache.Get(c, fmt.Sprintf("nio:otp:%s", email)).Result()
if err != nil {
return ""
}
return code
}
func checkCode(c *gin.Context, cache *redis.Client, email, code string) bool {
storage := getCode(c, cache, email)
if len(storage) == 0 {
return false
}
if storage != code {
return false
}
cache.Del(c, fmt.Sprintf("nio:top:%s", email))
return true
}
func setCode(c *gin.Context, cache *redis.Client, email, code string) {
cache.Set(c, fmt.Sprintf("nio:otp:%s", email), code, 5*time.Minute)
}
func generateCode(c *gin.Context, cache *redis.Client, email string) string {
code := utils.GenerateCode(6)
setCode(c, cache, email, code)
return code
}
func Verify(c *gin.Context, email string) error {
cache := utils.GetCacheFromContext(c)
code := generateCode(c, cache, email)
provider := channel.SystemInstance.GetMail()
return provider.SendMail(
email,
"Chat Nio | OTP Verification",
fmt.Sprintf("Your OTP code is: %s", code),
)
}
func SignUp(c *gin.Context, form RegisterForm) (string, error) {
db := utils.GetDBFromContext(c)
cache := utils.GetCacheFromContext(c)
username := strings.TrimSpace(form.Username)
password := strings.TrimSpace(form.Password)
email := strings.TrimSpace(form.Email)
code := strings.TrimSpace(form.Code)
if !utils.All(
validateUsername(username),
validatePassword(password),
validateEmail(email),
validateCode(code),
) {
return "", errors.New("invalid username/password/email format")
}
if !IsUserExist(db, username) {
return "", fmt.Errorf("username is already taken, please try another one username (your current username: %s)", username)
}
if !IsEmailExist(db, email) {
return "", fmt.Errorf("email is already taken, please try another one email (your current email: %s)", email)
}
if !checkCode(c, cache, email, code) {
return "", errors.New("invalid email verification code")
}
hash := utils.Sha2Encrypt(password)
user := &User{
Username: username,
Password: hash,
Email: email,
BindID: getMaxBindId(db) + 1,
Token: utils.Sha2Encrypt(email + username),
}
if _, err := db.Exec(`
INSERT INTO auth (username, password, email, bind_id, token)
VALUES (?, ?, ?, ?, ?)
`, user.Username, user.Password, user.Email, user.BindID, user.Token); err != nil {
return "", err
}
return user.GenerateToken()
}
func Login(c *gin.Context, form LoginForm) (string, error) {
db := utils.GetDBFromContext(c)
username := strings.TrimSpace(form.Username)
password := strings.TrimSpace(form.Password)
if !utils.All(
validateUsernameOrEmail(username),
validatePassword(password),
) {
return "", errors.New("invalid username or password format")
}
hash := utils.Sha2Encrypt(password)
// get user from db by username (or email) and password
var user User
if err := db.QueryRow(`
SELECT auth.id, auth.username, auth.password FROM auth
WHERE (auth.username = ? OR auth.email = ?) AND auth.password = ?
`, username, hash).Scan(&user.ID, &user.Username, &user.Password); err != nil {
return "", errors.New("invalid username or password")
}
return user.GenerateToken()
}
func DeepLogin(c *gin.Context, token string) (string, error) { func DeepLogin(c *gin.Context, token string) (string, error) {
if !useDeeptrain() { if !useDeeptrain() {
return "", errors.New("deeptrain feature is disabled") return "", errors.New("deeptrain mode is disabled")
} }
user := Validate(token) user := Validate(token)
@ -91,6 +214,41 @@ func DeepLogin(c *gin.Context, token string) (string, error) {
return u.GenerateToken() return u.GenerateToken()
} }
func Reset(c *gin.Context, form ResetForm) error {
db := utils.GetDBFromContext(c)
cache := utils.GetCacheFromContext(c)
email := strings.TrimSpace(form.Email)
code := strings.TrimSpace(form.Code)
password := strings.TrimSpace(form.Password)
if !utils.All(
validateEmail(email),
validateCode(code),
validatePassword(password),
) {
return errors.New("invalid email/code/password format")
}
if !IsEmailExist(db, email) {
return errors.New("email is not registered")
}
if !checkCode(c, cache, email, code) {
return errors.New("invalid email verification code")
}
hash := utils.Sha2Encrypt(password)
if _, err := db.Exec(`
UPDATE auth SET password = ? WHERE email = ?
`, hash, email); err != nil {
return err
}
return nil
}
func (u *User) Validate(c *gin.Context) bool { func (u *User) Validate(c *gin.Context) bool {
if u.Username == "" || u.Password == "" { if u.Username == "" || u.Password == "" {
return false return false

View File

@ -7,10 +7,32 @@ import (
"strings" "strings"
) )
type RegisterForm struct {
Username string `form:"username" binding:"required"`
Password string `form:"password" binding:"required"`
Email string `form:"email" binding:"required"`
Code string `form:"code" binding:"required"`
}
type VerifyForm struct {
Email string `form:"email" binding:"required"`
}
type LoginForm struct {
Username string `form:"username" binding:"required"`
Password string `form:"password" binding:"required"`
}
type DeepLoginForm struct { type DeepLoginForm struct {
Token string `form:"token" binding:"required"` Token string `form:"token" binding:"required"`
} }
type ResetForm struct {
Email string `form:"email" binding:"required"`
Code string `form:"code" binding:"required"`
Password string `form:"password" binding:"required"`
}
type BuyForm struct { type BuyForm struct {
Quota int `json:"quota" binding:"required"` Quota int `json:"quota" binding:"required"`
} }
@ -111,8 +133,16 @@ func RequireEnterprise(c *gin.Context) *User {
return user return user
} }
func LoginAPI(c *gin.Context) { func RegisterAPI(c *gin.Context) {
var form DeepLoginForm if useDeeptrain() {
c.JSON(http.StatusOK, gin.H{
"status": false,
"error": "this api is not available for deeptrain mode",
})
return
}
var form RegisterForm
if err := c.ShouldBind(&form); err != nil { if err := c.ShouldBind(&form); err != nil {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"status": false, "status": false,
@ -121,7 +151,7 @@ func LoginAPI(c *gin.Context) {
return return
} }
token, err := DeepLogin(c, form.Token) token, err := SignUp(c, form)
if err != nil { if err != nil {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"status": false, "status": false,
@ -136,6 +166,94 @@ func LoginAPI(c *gin.Context) {
}) })
} }
func LoginAPI(c *gin.Context) {
var token string
var err error
if useDeeptrain() {
var form DeepLoginForm
if err := c.ShouldBind(&form); err != nil {
c.JSON(http.StatusOK, gin.H{
"status": false,
"error": "bad request",
})
return
}
token, err = DeepLogin(c, form.Token)
} else {
var form LoginForm
if err := c.ShouldBind(&form); err != nil {
c.JSON(http.StatusOK, gin.H{
"status": false,
"error": "bad request",
})
return
}
token, err = Login(c, form)
}
if err != nil {
c.JSON(http.StatusOK, gin.H{
"status": false,
"error": err.Error(),
})
return
}
c.JSON(http.StatusOK, gin.H{
"status": true,
"token": token,
})
}
func VerifyAPI(c *gin.Context) {
var form VerifyForm
if err := c.ShouldBind(&form); err != nil {
c.JSON(http.StatusOK, gin.H{
"status": false,
"error": "bad request",
})
return
}
if err := Verify(c, form.Email); err != nil {
c.JSON(http.StatusOK, gin.H{
"status": false,
"error": err.Error(),
})
return
}
c.JSON(http.StatusOK, gin.H{
"status": true,
})
}
func ResetAPI(c *gin.Context) {
var form ResetForm
if err := c.ShouldBind(&form); err != nil {
c.JSON(http.StatusOK, gin.H{
"status": false,
"error": "bad request",
})
return
}
if err := Reset(c, form); err != nil {
c.JSON(http.StatusOK, gin.H{
"status": false,
"error": err.Error(),
})
return
}
c.JSON(http.StatusOK, gin.H{
"status": true,
})
}
func StateAPI(c *gin.Context) { func StateAPI(c *gin.Context) {
username := utils.GetUserFromContext(c) username := utils.GetUserFromContext(c)
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{

View File

@ -3,6 +3,9 @@ package auth
import "github.com/gin-gonic/gin" import "github.com/gin-gonic/gin"
func Register(app *gin.Engine) { func Register(app *gin.Engine) {
app.POST("/verify", VerifyAPI)
app.POST("/reset", ResetAPI)
app.POST("/register", RegisterAPI)
app.POST("/login", LoginAPI) app.POST("/login", LoginAPI)
app.POST("/state", StateAPI) app.POST("/state", StateAPI)
app.GET("/apikey", KeyAPI) app.GET("/apikey", KeyAPI)

View File

@ -9,6 +9,7 @@ import (
type User struct { type User struct {
ID int64 `json:"id"` ID int64 `json:"id"`
Username string `json:"username"` Username string `json:"username"`
Email string `json:"email"`
BindID int64 `json:"bind_id"` BindID int64 `json:"bind_id"`
Password string `json:"password"` Password string `json:"password"`
Token string `json:"token"` Token string `json:"token"`
@ -56,6 +57,20 @@ func (u *User) GetID(db *sql.DB) int64 {
return u.ID return u.ID
} }
func (u *User) GetEmail(db *sql.DB) string {
if len(u.Email) > 0 {
return u.Email
}
var email sql.NullString
if err := db.QueryRow("SELECT email FROM auth WHERE username = ?", u.Username).Scan(&email); err != nil {
return ""
}
u.Email = email.String
return u.Email
}
func IsUserExist(db *sql.DB, username string) bool { func IsUserExist(db *sql.DB, username string) bool {
var count int var count int
if err := db.QueryRow("SELECT COUNT(*) FROM auth WHERE username = ?", username).Scan(&count); err != nil { if err := db.QueryRow("SELECT COUNT(*) FROM auth WHERE username = ?", username).Scan(&count); err != nil {
@ -64,6 +79,22 @@ func IsUserExist(db *sql.DB, username string) bool {
return count > 0 return count > 0
} }
func IsEmailExist(db *sql.DB, email string) bool {
var count int
if err := db.QueryRow("SELECT COUNT(*) FROM auth WHERE email = ?", email).Scan(&count); err != nil {
return false
}
return count > 0
}
func getMaxBindId(db *sql.DB) int64 {
var max int64
if err := db.QueryRow("SELECT MAX(bind_id) FROM auth").Scan(&max); err != nil {
return 0
}
return max
}
func GetGroup(db *sql.DB, user *User) string { func GetGroup(db *sql.DB, user *User) string {
if user == nil { if user == nil {
return globals.AnonymousType return globals.AnonymousType

36
auth/validators.go Normal file
View File

@ -0,0 +1,36 @@
package auth
import (
"regexp"
"strings"
)
func isInRange(content string, min, max int) bool {
content = strings.TrimSpace(content)
return len(content) >= min && len(content) <= max
}
func validateUsername(username string) bool {
return isInRange(username, 2, 24)
}
func validateUsernameOrEmail(username string) bool {
return isInRange(username, 1, 255)
}
func validatePassword(password string) bool {
return isInRange(password, 6, 36)
}
func validateEmail(email string) bool {
if !isInRange(email, 1, 255) {
return false
}
exp := regexp.MustCompile(`^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$`)
return exp.MatchString(email)
}
func validateCode(code string) bool {
return isInRange(code, 1, 64)
}

3
migration/3.8.sql Normal file
View File

@ -0,0 +1,3 @@
ALTER TABLE auth
ADD COLUMN email VARCHAR(255) UNIQUE,
ADD COLUMN is_banned BOOLEAN DEFAULT FALSE;

View File

@ -254,3 +254,21 @@ func GetIndexSafe[T any](arr []T, index int) *T {
} }
return &arr[index] return &arr[index]
} }
func All(arr ...bool) bool {
for _, v := range arr {
if !v {
return false
}
}
return true
}
func Any(arr ...bool) bool {
for _, v := range arr {
if v {
return true
}
}
return false
}

View File

@ -25,19 +25,17 @@ func NewSmtpPoster(host string, port int, username string, password string, from
} }
} }
func (s *SmtpPoster) SendMail(to string, subject string, body string) { func (s *SmtpPoster) SendMail(to string, subject string, body string) error {
addr := fmt.Sprintf("%s:%d", s.Host, s.Port) addr := fmt.Sprintf("%s:%d", s.Host, s.Port)
auth := smtp.PlainAuth("", s.From, s.Password, s.Host) auth := smtp.PlainAuth("", s.From, s.Password, s.Host)
err := smtpRequestWithTLS(addr, auth, s.From, []string{to},
return smtpRequestWithTLS(addr, auth, s.From, []string{to},
[]byte(formatMail(map[string]string{ []byte(formatMail(map[string]string{
"From": fmt.Sprintf("%s <%s>", s.Username, s.From), "From": fmt.Sprintf("%s <%s>", s.Username, s.From),
"To": to, "To": to,
"Subject": subject, "Subject": subject,
"Content-Type": "text/html; charset=utf-8", "Content-Type": "text/html; charset=utf-8",
}, body))) }, body)))
if err != nil {
return
}
} }
func dial(addr string) (*smtp.Client, error) { func dial(addr string) (*smtp.Client, error) {