mirror of
https://github.com/coaidev/coai.git
synced 2025-05-19 04:50:14 +09:00
update: all in one feature
This commit is contained in:
parent
aa627fb61d
commit
abff5b9821
160
auth/auth.go
160
auth/auth.go
@ -1,6 +1,7 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"chat/channel"
|
||||
"chat/globals"
|
||||
"chat/utils"
|
||||
"database/sql"
|
||||
@ -8,7 +9,9 @@ import (
|
||||
"fmt"
|
||||
"github.com/dgrijalva/jwt-go"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/go-redis/redis/v8"
|
||||
"github.com/spf13/viper"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
@ -54,9 +57,129 @@ func ParseApiKey(c *gin.Context, key string) *User {
|
||||
return &user
|
||||
}
|
||||
|
||||
func getCode(c *gin.Context, cache *redis.Client, email string) string {
|
||||
code, err := cache.Get(c, fmt.Sprintf("nio:otp:%s", email)).Result()
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return code
|
||||
}
|
||||
|
||||
func checkCode(c *gin.Context, cache *redis.Client, email, code string) bool {
|
||||
storage := getCode(c, cache, email)
|
||||
if len(storage) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
if storage != code {
|
||||
return false
|
||||
}
|
||||
|
||||
cache.Del(c, fmt.Sprintf("nio:top:%s", email))
|
||||
return true
|
||||
}
|
||||
|
||||
func setCode(c *gin.Context, cache *redis.Client, email, code string) {
|
||||
cache.Set(c, fmt.Sprintf("nio:otp:%s", email), code, 5*time.Minute)
|
||||
}
|
||||
|
||||
func generateCode(c *gin.Context, cache *redis.Client, email string) string {
|
||||
code := utils.GenerateCode(6)
|
||||
setCode(c, cache, email, code)
|
||||
return code
|
||||
}
|
||||
|
||||
func Verify(c *gin.Context, email string) error {
|
||||
cache := utils.GetCacheFromContext(c)
|
||||
code := generateCode(c, cache, email)
|
||||
|
||||
provider := channel.SystemInstance.GetMail()
|
||||
return provider.SendMail(
|
||||
email,
|
||||
"Chat Nio | OTP Verification",
|
||||
fmt.Sprintf("Your OTP code is: %s", code),
|
||||
)
|
||||
}
|
||||
|
||||
func SignUp(c *gin.Context, form RegisterForm) (string, error) {
|
||||
db := utils.GetDBFromContext(c)
|
||||
cache := utils.GetCacheFromContext(c)
|
||||
|
||||
username := strings.TrimSpace(form.Username)
|
||||
password := strings.TrimSpace(form.Password)
|
||||
email := strings.TrimSpace(form.Email)
|
||||
code := strings.TrimSpace(form.Code)
|
||||
|
||||
if !utils.All(
|
||||
validateUsername(username),
|
||||
validatePassword(password),
|
||||
validateEmail(email),
|
||||
validateCode(code),
|
||||
) {
|
||||
return "", errors.New("invalid username/password/email format")
|
||||
}
|
||||
|
||||
if !IsUserExist(db, username) {
|
||||
return "", fmt.Errorf("username is already taken, please try another one username (your current username: %s)", username)
|
||||
}
|
||||
|
||||
if !IsEmailExist(db, email) {
|
||||
return "", fmt.Errorf("email is already taken, please try another one email (your current email: %s)", email)
|
||||
}
|
||||
|
||||
if !checkCode(c, cache, email, code) {
|
||||
return "", errors.New("invalid email verification code")
|
||||
}
|
||||
|
||||
hash := utils.Sha2Encrypt(password)
|
||||
|
||||
user := &User{
|
||||
Username: username,
|
||||
Password: hash,
|
||||
Email: email,
|
||||
BindID: getMaxBindId(db) + 1,
|
||||
Token: utils.Sha2Encrypt(email + username),
|
||||
}
|
||||
|
||||
if _, err := db.Exec(`
|
||||
INSERT INTO auth (username, password, email, bind_id, token)
|
||||
VALUES (?, ?, ?, ?, ?)
|
||||
`, user.Username, user.Password, user.Email, user.BindID, user.Token); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return user.GenerateToken()
|
||||
}
|
||||
|
||||
func Login(c *gin.Context, form LoginForm) (string, error) {
|
||||
db := utils.GetDBFromContext(c)
|
||||
username := strings.TrimSpace(form.Username)
|
||||
password := strings.TrimSpace(form.Password)
|
||||
|
||||
if !utils.All(
|
||||
validateUsernameOrEmail(username),
|
||||
validatePassword(password),
|
||||
) {
|
||||
return "", errors.New("invalid username or password format")
|
||||
}
|
||||
|
||||
hash := utils.Sha2Encrypt(password)
|
||||
|
||||
// get user from db by username (or email) and password
|
||||
var user User
|
||||
if err := db.QueryRow(`
|
||||
SELECT auth.id, auth.username, auth.password FROM auth
|
||||
WHERE (auth.username = ? OR auth.email = ?) AND auth.password = ?
|
||||
`, username, hash).Scan(&user.ID, &user.Username, &user.Password); err != nil {
|
||||
return "", errors.New("invalid username or password")
|
||||
}
|
||||
|
||||
return user.GenerateToken()
|
||||
}
|
||||
|
||||
func DeepLogin(c *gin.Context, token string) (string, error) {
|
||||
if !useDeeptrain() {
|
||||
return "", errors.New("deeptrain feature is disabled")
|
||||
return "", errors.New("deeptrain mode is disabled")
|
||||
}
|
||||
|
||||
user := Validate(token)
|
||||
@ -91,6 +214,41 @@ func DeepLogin(c *gin.Context, token string) (string, error) {
|
||||
return u.GenerateToken()
|
||||
}
|
||||
|
||||
func Reset(c *gin.Context, form ResetForm) error {
|
||||
db := utils.GetDBFromContext(c)
|
||||
cache := utils.GetCacheFromContext(c)
|
||||
|
||||
email := strings.TrimSpace(form.Email)
|
||||
code := strings.TrimSpace(form.Code)
|
||||
password := strings.TrimSpace(form.Password)
|
||||
|
||||
if !utils.All(
|
||||
validateEmail(email),
|
||||
validateCode(code),
|
||||
validatePassword(password),
|
||||
) {
|
||||
return errors.New("invalid email/code/password format")
|
||||
}
|
||||
|
||||
if !IsEmailExist(db, email) {
|
||||
return errors.New("email is not registered")
|
||||
}
|
||||
|
||||
if !checkCode(c, cache, email, code) {
|
||||
return errors.New("invalid email verification code")
|
||||
}
|
||||
|
||||
hash := utils.Sha2Encrypt(password)
|
||||
|
||||
if _, err := db.Exec(`
|
||||
UPDATE auth SET password = ? WHERE email = ?
|
||||
`, hash, email); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (u *User) Validate(c *gin.Context) bool {
|
||||
if u.Username == "" || u.Password == "" {
|
||||
return false
|
||||
|
@ -7,10 +7,32 @@ import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
type RegisterForm struct {
|
||||
Username string `form:"username" binding:"required"`
|
||||
Password string `form:"password" binding:"required"`
|
||||
Email string `form:"email" binding:"required"`
|
||||
Code string `form:"code" binding:"required"`
|
||||
}
|
||||
|
||||
type VerifyForm struct {
|
||||
Email string `form:"email" binding:"required"`
|
||||
}
|
||||
|
||||
type LoginForm struct {
|
||||
Username string `form:"username" binding:"required"`
|
||||
Password string `form:"password" binding:"required"`
|
||||
}
|
||||
|
||||
type DeepLoginForm struct {
|
||||
Token string `form:"token" binding:"required"`
|
||||
}
|
||||
|
||||
type ResetForm struct {
|
||||
Email string `form:"email" binding:"required"`
|
||||
Code string `form:"code" binding:"required"`
|
||||
Password string `form:"password" binding:"required"`
|
||||
}
|
||||
|
||||
type BuyForm struct {
|
||||
Quota int `json:"quota" binding:"required"`
|
||||
}
|
||||
@ -111,8 +133,16 @@ func RequireEnterprise(c *gin.Context) *User {
|
||||
return user
|
||||
}
|
||||
|
||||
func LoginAPI(c *gin.Context) {
|
||||
var form DeepLoginForm
|
||||
func RegisterAPI(c *gin.Context) {
|
||||
if useDeeptrain() {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"status": false,
|
||||
"error": "this api is not available for deeptrain mode",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
var form RegisterForm
|
||||
if err := c.ShouldBind(&form); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"status": false,
|
||||
@ -121,7 +151,7 @@ func LoginAPI(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
token, err := DeepLogin(c, form.Token)
|
||||
token, err := SignUp(c, form)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"status": false,
|
||||
@ -136,6 +166,94 @@ func LoginAPI(c *gin.Context) {
|
||||
})
|
||||
}
|
||||
|
||||
func LoginAPI(c *gin.Context) {
|
||||
var token string
|
||||
var err error
|
||||
|
||||
if useDeeptrain() {
|
||||
var form DeepLoginForm
|
||||
if err := c.ShouldBind(&form); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"status": false,
|
||||
"error": "bad request",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
token, err = DeepLogin(c, form.Token)
|
||||
} else {
|
||||
var form LoginForm
|
||||
if err := c.ShouldBind(&form); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"status": false,
|
||||
"error": "bad request",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
token, err = Login(c, form)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"status": false,
|
||||
"error": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"status": true,
|
||||
"token": token,
|
||||
})
|
||||
}
|
||||
|
||||
func VerifyAPI(c *gin.Context) {
|
||||
var form VerifyForm
|
||||
if err := c.ShouldBind(&form); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"status": false,
|
||||
"error": "bad request",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if err := Verify(c, form.Email); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"status": false,
|
||||
"error": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"status": true,
|
||||
})
|
||||
}
|
||||
|
||||
func ResetAPI(c *gin.Context) {
|
||||
var form ResetForm
|
||||
if err := c.ShouldBind(&form); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"status": false,
|
||||
"error": "bad request",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if err := Reset(c, form); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"status": false,
|
||||
"error": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"status": true,
|
||||
})
|
||||
}
|
||||
|
||||
func StateAPI(c *gin.Context) {
|
||||
username := utils.GetUserFromContext(c)
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
|
@ -3,6 +3,9 @@ package auth
|
||||
import "github.com/gin-gonic/gin"
|
||||
|
||||
func Register(app *gin.Engine) {
|
||||
app.POST("/verify", VerifyAPI)
|
||||
app.POST("/reset", ResetAPI)
|
||||
app.POST("/register", RegisterAPI)
|
||||
app.POST("/login", LoginAPI)
|
||||
app.POST("/state", StateAPI)
|
||||
app.GET("/apikey", KeyAPI)
|
||||
|
@ -9,6 +9,7 @@ import (
|
||||
type User struct {
|
||||
ID int64 `json:"id"`
|
||||
Username string `json:"username"`
|
||||
Email string `json:"email"`
|
||||
BindID int64 `json:"bind_id"`
|
||||
Password string `json:"password"`
|
||||
Token string `json:"token"`
|
||||
@ -56,6 +57,20 @@ func (u *User) GetID(db *sql.DB) int64 {
|
||||
return u.ID
|
||||
}
|
||||
|
||||
func (u *User) GetEmail(db *sql.DB) string {
|
||||
if len(u.Email) > 0 {
|
||||
return u.Email
|
||||
}
|
||||
|
||||
var email sql.NullString
|
||||
if err := db.QueryRow("SELECT email FROM auth WHERE username = ?", u.Username).Scan(&email); err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
u.Email = email.String
|
||||
return u.Email
|
||||
}
|
||||
|
||||
func IsUserExist(db *sql.DB, username string) bool {
|
||||
var count int
|
||||
if err := db.QueryRow("SELECT COUNT(*) FROM auth WHERE username = ?", username).Scan(&count); err != nil {
|
||||
@ -64,6 +79,22 @@ func IsUserExist(db *sql.DB, username string) bool {
|
||||
return count > 0
|
||||
}
|
||||
|
||||
func IsEmailExist(db *sql.DB, email string) bool {
|
||||
var count int
|
||||
if err := db.QueryRow("SELECT COUNT(*) FROM auth WHERE email = ?", email).Scan(&count); err != nil {
|
||||
return false
|
||||
}
|
||||
return count > 0
|
||||
}
|
||||
|
||||
func getMaxBindId(db *sql.DB) int64 {
|
||||
var max int64
|
||||
if err := db.QueryRow("SELECT MAX(bind_id) FROM auth").Scan(&max); err != nil {
|
||||
return 0
|
||||
}
|
||||
return max
|
||||
}
|
||||
|
||||
func GetGroup(db *sql.DB, user *User) string {
|
||||
if user == nil {
|
||||
return globals.AnonymousType
|
||||
|
36
auth/validators.go
Normal file
36
auth/validators.go
Normal file
@ -0,0 +1,36 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"regexp"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func isInRange(content string, min, max int) bool {
|
||||
content = strings.TrimSpace(content)
|
||||
return len(content) >= min && len(content) <= max
|
||||
}
|
||||
|
||||
func validateUsername(username string) bool {
|
||||
return isInRange(username, 2, 24)
|
||||
}
|
||||
|
||||
func validateUsernameOrEmail(username string) bool {
|
||||
return isInRange(username, 1, 255)
|
||||
}
|
||||
|
||||
func validatePassword(password string) bool {
|
||||
return isInRange(password, 6, 36)
|
||||
}
|
||||
|
||||
func validateEmail(email string) bool {
|
||||
if !isInRange(email, 1, 255) {
|
||||
return false
|
||||
}
|
||||
|
||||
exp := regexp.MustCompile(`^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$`)
|
||||
return exp.MatchString(email)
|
||||
}
|
||||
|
||||
func validateCode(code string) bool {
|
||||
return isInRange(code, 1, 64)
|
||||
}
|
3
migration/3.8.sql
Normal file
3
migration/3.8.sql
Normal file
@ -0,0 +1,3 @@
|
||||
ALTER TABLE auth
|
||||
ADD COLUMN email VARCHAR(255) UNIQUE,
|
||||
ADD COLUMN is_banned BOOLEAN DEFAULT FALSE;
|
@ -254,3 +254,21 @@ func GetIndexSafe[T any](arr []T, index int) *T {
|
||||
}
|
||||
return &arr[index]
|
||||
}
|
||||
|
||||
func All(arr ...bool) bool {
|
||||
for _, v := range arr {
|
||||
if !v {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func Any(arr ...bool) bool {
|
||||
for _, v := range arr {
|
||||
if v {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
@ -25,19 +25,17 @@ func NewSmtpPoster(host string, port int, username string, password string, from
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SmtpPoster) SendMail(to string, subject string, body string) {
|
||||
func (s *SmtpPoster) SendMail(to string, subject string, body string) error {
|
||||
addr := fmt.Sprintf("%s:%d", s.Host, s.Port)
|
||||
auth := smtp.PlainAuth("", s.From, s.Password, s.Host)
|
||||
err := smtpRequestWithTLS(addr, auth, s.From, []string{to},
|
||||
|
||||
return smtpRequestWithTLS(addr, auth, s.From, []string{to},
|
||||
[]byte(formatMail(map[string]string{
|
||||
"From": fmt.Sprintf("%s <%s>", s.Username, s.From),
|
||||
"To": to,
|
||||
"Subject": subject,
|
||||
"Content-Type": "text/html; charset=utf-8",
|
||||
}, body)))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func dial(addr string) (*smtp.Client, error) {
|
||||
|
Loading…
Reference in New Issue
Block a user