diff --git a/auth/auth.go b/auth/auth.go index c8784bb..cd9c45d 100644 --- a/auth/auth.go +++ b/auth/auth.go @@ -1,6 +1,7 @@ package auth import ( + "chat/channel" "chat/globals" "chat/utils" "database/sql" @@ -8,7 +9,9 @@ import ( "fmt" "github.com/dgrijalva/jwt-go" "github.com/gin-gonic/gin" + "github.com/go-redis/redis/v8" "github.com/spf13/viper" + "strings" "time" ) @@ -54,9 +57,129 @@ func ParseApiKey(c *gin.Context, key string) *User { return &user } +func getCode(c *gin.Context, cache *redis.Client, email string) string { + code, err := cache.Get(c, fmt.Sprintf("nio:otp:%s", email)).Result() + if err != nil { + return "" + } + return code +} + +func checkCode(c *gin.Context, cache *redis.Client, email, code string) bool { + storage := getCode(c, cache, email) + if len(storage) == 0 { + return false + } + + if storage != code { + return false + } + + cache.Del(c, fmt.Sprintf("nio:top:%s", email)) + return true +} + +func setCode(c *gin.Context, cache *redis.Client, email, code string) { + cache.Set(c, fmt.Sprintf("nio:otp:%s", email), code, 5*time.Minute) +} + +func generateCode(c *gin.Context, cache *redis.Client, email string) string { + code := utils.GenerateCode(6) + setCode(c, cache, email, code) + return code +} + +func Verify(c *gin.Context, email string) error { + cache := utils.GetCacheFromContext(c) + code := generateCode(c, cache, email) + + provider := channel.SystemInstance.GetMail() + return provider.SendMail( + email, + "Chat Nio | OTP Verification", + fmt.Sprintf("Your OTP code is: %s", code), + ) +} + +func SignUp(c *gin.Context, form RegisterForm) (string, error) { + db := utils.GetDBFromContext(c) + cache := utils.GetCacheFromContext(c) + + username := strings.TrimSpace(form.Username) + password := strings.TrimSpace(form.Password) + email := strings.TrimSpace(form.Email) + code := strings.TrimSpace(form.Code) + + if !utils.All( + validateUsername(username), + validatePassword(password), + validateEmail(email), + validateCode(code), + ) { + return "", errors.New("invalid username/password/email format") + } + + if !IsUserExist(db, username) { + return "", fmt.Errorf("username is already taken, please try another one username (your current username: %s)", username) + } + + if !IsEmailExist(db, email) { + return "", fmt.Errorf("email is already taken, please try another one email (your current email: %s)", email) + } + + if !checkCode(c, cache, email, code) { + return "", errors.New("invalid email verification code") + } + + hash := utils.Sha2Encrypt(password) + + user := &User{ + Username: username, + Password: hash, + Email: email, + BindID: getMaxBindId(db) + 1, + Token: utils.Sha2Encrypt(email + username), + } + + if _, err := db.Exec(` + INSERT INTO auth (username, password, email, bind_id, token) + VALUES (?, ?, ?, ?, ?) + `, user.Username, user.Password, user.Email, user.BindID, user.Token); err != nil { + return "", err + } + + return user.GenerateToken() +} + +func Login(c *gin.Context, form LoginForm) (string, error) { + db := utils.GetDBFromContext(c) + username := strings.TrimSpace(form.Username) + password := strings.TrimSpace(form.Password) + + if !utils.All( + validateUsernameOrEmail(username), + validatePassword(password), + ) { + return "", errors.New("invalid username or password format") + } + + hash := utils.Sha2Encrypt(password) + + // get user from db by username (or email) and password + var user User + if err := db.QueryRow(` + SELECT auth.id, auth.username, auth.password FROM auth + WHERE (auth.username = ? OR auth.email = ?) AND auth.password = ? + `, username, hash).Scan(&user.ID, &user.Username, &user.Password); err != nil { + return "", errors.New("invalid username or password") + } + + return user.GenerateToken() +} + func DeepLogin(c *gin.Context, token string) (string, error) { if !useDeeptrain() { - return "", errors.New("deeptrain feature is disabled") + return "", errors.New("deeptrain mode is disabled") } user := Validate(token) @@ -91,6 +214,41 @@ func DeepLogin(c *gin.Context, token string) (string, error) { return u.GenerateToken() } +func Reset(c *gin.Context, form ResetForm) error { + db := utils.GetDBFromContext(c) + cache := utils.GetCacheFromContext(c) + + email := strings.TrimSpace(form.Email) + code := strings.TrimSpace(form.Code) + password := strings.TrimSpace(form.Password) + + if !utils.All( + validateEmail(email), + validateCode(code), + validatePassword(password), + ) { + return errors.New("invalid email/code/password format") + } + + if !IsEmailExist(db, email) { + return errors.New("email is not registered") + } + + if !checkCode(c, cache, email, code) { + return errors.New("invalid email verification code") + } + + hash := utils.Sha2Encrypt(password) + + if _, err := db.Exec(` + UPDATE auth SET password = ? WHERE email = ? + `, hash, email); err != nil { + return err + } + + return nil +} + func (u *User) Validate(c *gin.Context) bool { if u.Username == "" || u.Password == "" { return false diff --git a/auth/controller.go b/auth/controller.go index 8bca9c6..fa7e7cb 100644 --- a/auth/controller.go +++ b/auth/controller.go @@ -7,10 +7,32 @@ import ( "strings" ) +type RegisterForm struct { + Username string `form:"username" binding:"required"` + Password string `form:"password" binding:"required"` + Email string `form:"email" binding:"required"` + Code string `form:"code" binding:"required"` +} + +type VerifyForm struct { + Email string `form:"email" binding:"required"` +} + +type LoginForm struct { + Username string `form:"username" binding:"required"` + Password string `form:"password" binding:"required"` +} + type DeepLoginForm struct { Token string `form:"token" binding:"required"` } +type ResetForm struct { + Email string `form:"email" binding:"required"` + Code string `form:"code" binding:"required"` + Password string `form:"password" binding:"required"` +} + type BuyForm struct { Quota int `json:"quota" binding:"required"` } @@ -111,8 +133,16 @@ func RequireEnterprise(c *gin.Context) *User { return user } -func LoginAPI(c *gin.Context) { - var form DeepLoginForm +func RegisterAPI(c *gin.Context) { + if useDeeptrain() { + c.JSON(http.StatusOK, gin.H{ + "status": false, + "error": "this api is not available for deeptrain mode", + }) + return + } + + var form RegisterForm if err := c.ShouldBind(&form); err != nil { c.JSON(http.StatusOK, gin.H{ "status": false, @@ -121,7 +151,7 @@ func LoginAPI(c *gin.Context) { return } - token, err := DeepLogin(c, form.Token) + token, err := SignUp(c, form) if err != nil { c.JSON(http.StatusOK, gin.H{ "status": false, @@ -136,6 +166,94 @@ func LoginAPI(c *gin.Context) { }) } +func LoginAPI(c *gin.Context) { + var token string + var err error + + if useDeeptrain() { + var form DeepLoginForm + if err := c.ShouldBind(&form); err != nil { + c.JSON(http.StatusOK, gin.H{ + "status": false, + "error": "bad request", + }) + return + } + + token, err = DeepLogin(c, form.Token) + } else { + var form LoginForm + if err := c.ShouldBind(&form); err != nil { + c.JSON(http.StatusOK, gin.H{ + "status": false, + "error": "bad request", + }) + return + } + + token, err = Login(c, form) + } + + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "status": false, + "error": err.Error(), + }) + return + } + + c.JSON(http.StatusOK, gin.H{ + "status": true, + "token": token, + }) +} + +func VerifyAPI(c *gin.Context) { + var form VerifyForm + if err := c.ShouldBind(&form); err != nil { + c.JSON(http.StatusOK, gin.H{ + "status": false, + "error": "bad request", + }) + return + } + + if err := Verify(c, form.Email); err != nil { + c.JSON(http.StatusOK, gin.H{ + "status": false, + "error": err.Error(), + }) + return + } + + c.JSON(http.StatusOK, gin.H{ + "status": true, + }) +} + +func ResetAPI(c *gin.Context) { + var form ResetForm + if err := c.ShouldBind(&form); err != nil { + c.JSON(http.StatusOK, gin.H{ + "status": false, + "error": "bad request", + }) + return + } + + if err := Reset(c, form); err != nil { + c.JSON(http.StatusOK, gin.H{ + "status": false, + "error": err.Error(), + }) + return + } + + c.JSON(http.StatusOK, gin.H{ + "status": true, + }) +} + func StateAPI(c *gin.Context) { username := utils.GetUserFromContext(c) c.JSON(http.StatusOK, gin.H{ diff --git a/auth/router.go b/auth/router.go index b73277b..b3ceb06 100644 --- a/auth/router.go +++ b/auth/router.go @@ -3,6 +3,9 @@ package auth import "github.com/gin-gonic/gin" func Register(app *gin.Engine) { + app.POST("/verify", VerifyAPI) + app.POST("/reset", ResetAPI) + app.POST("/register", RegisterAPI) app.POST("/login", LoginAPI) app.POST("/state", StateAPI) app.GET("/apikey", KeyAPI) diff --git a/auth/struct.go b/auth/struct.go index 250ba46..ec21ced 100644 --- a/auth/struct.go +++ b/auth/struct.go @@ -9,6 +9,7 @@ import ( type User struct { ID int64 `json:"id"` Username string `json:"username"` + Email string `json:"email"` BindID int64 `json:"bind_id"` Password string `json:"password"` Token string `json:"token"` @@ -56,6 +57,20 @@ func (u *User) GetID(db *sql.DB) int64 { return u.ID } +func (u *User) GetEmail(db *sql.DB) string { + if len(u.Email) > 0 { + return u.Email + } + + var email sql.NullString + if err := db.QueryRow("SELECT email FROM auth WHERE username = ?", u.Username).Scan(&email); err != nil { + return "" + } + + u.Email = email.String + return u.Email +} + func IsUserExist(db *sql.DB, username string) bool { var count int if err := db.QueryRow("SELECT COUNT(*) FROM auth WHERE username = ?", username).Scan(&count); err != nil { @@ -64,6 +79,22 @@ func IsUserExist(db *sql.DB, username string) bool { return count > 0 } +func IsEmailExist(db *sql.DB, email string) bool { + var count int + if err := db.QueryRow("SELECT COUNT(*) FROM auth WHERE email = ?", email).Scan(&count); err != nil { + return false + } + return count > 0 +} + +func getMaxBindId(db *sql.DB) int64 { + var max int64 + if err := db.QueryRow("SELECT MAX(bind_id) FROM auth").Scan(&max); err != nil { + return 0 + } + return max +} + func GetGroup(db *sql.DB, user *User) string { if user == nil { return globals.AnonymousType diff --git a/auth/validators.go b/auth/validators.go new file mode 100644 index 0000000..e1c6cd5 --- /dev/null +++ b/auth/validators.go @@ -0,0 +1,36 @@ +package auth + +import ( + "regexp" + "strings" +) + +func isInRange(content string, min, max int) bool { + content = strings.TrimSpace(content) + return len(content) >= min && len(content) <= max +} + +func validateUsername(username string) bool { + return isInRange(username, 2, 24) +} + +func validateUsernameOrEmail(username string) bool { + return isInRange(username, 1, 255) +} + +func validatePassword(password string) bool { + return isInRange(password, 6, 36) +} + +func validateEmail(email string) bool { + if !isInRange(email, 1, 255) { + return false + } + + exp := regexp.MustCompile(`^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$`) + return exp.MatchString(email) +} + +func validateCode(code string) bool { + return isInRange(code, 1, 64) +} diff --git a/migration/3.8.sql b/migration/3.8.sql new file mode 100644 index 0000000..de7b478 --- /dev/null +++ b/migration/3.8.sql @@ -0,0 +1,3 @@ +ALTER TABLE auth + ADD COLUMN email VARCHAR(255) UNIQUE, + ADD COLUMN is_banned BOOLEAN DEFAULT FALSE; diff --git a/utils/base.go b/utils/base.go index e177b23..00eb6f7 100644 --- a/utils/base.go +++ b/utils/base.go @@ -254,3 +254,21 @@ func GetIndexSafe[T any](arr []T, index int) *T { } return &arr[index] } + +func All(arr ...bool) bool { + for _, v := range arr { + if !v { + return false + } + } + return true +} + +func Any(arr ...bool) bool { + for _, v := range arr { + if v { + return true + } + } + return false +} diff --git a/utils/smtp.go b/utils/smtp.go index 95178cb..e368c81 100644 --- a/utils/smtp.go +++ b/utils/smtp.go @@ -25,19 +25,17 @@ func NewSmtpPoster(host string, port int, username string, password string, from } } -func (s *SmtpPoster) SendMail(to string, subject string, body string) { +func (s *SmtpPoster) SendMail(to string, subject string, body string) error { addr := fmt.Sprintf("%s:%d", s.Host, s.Port) auth := smtp.PlainAuth("", s.From, s.Password, s.Host) - err := smtpRequestWithTLS(addr, auth, s.From, []string{to}, + + return smtpRequestWithTLS(addr, auth, s.From, []string{to}, []byte(formatMail(map[string]string{ "From": fmt.Sprintf("%s <%s>", s.Username, s.From), "To": to, "Subject": subject, "Content-Type": "text/html; charset=utf-8", }, body))) - if err != nil { - return - } } func dial(addr string) (*smtp.Client, error) {