diff --git a/api.go b/api/anonymous.go similarity index 82% rename from api.go rename to api/anonymous.go index d45231f..8f7074c 100644 --- a/api.go +++ b/api/anonymous.go @@ -1,6 +1,8 @@ -package main +package api import ( + "chat/connection" + "chat/utils" "context" "fmt" "github.com/gin-gonic/gin" @@ -14,13 +16,8 @@ type AnonymousRequestBody struct { Message string `json:"message" required:"true"` } -type ChatGPTMessage struct { - Role string `json:"role"` - Content string `json:"content"` -} - func GetAnonymousResponse(message string) (string, error) { - res, err := Post(viper.GetString("openai.anonymous_endpoint")+"/chat/completions", map[string]string{ + res, err := utils.Post(viper.GetString("openai.anonymous_endpoint")+"/chat/completions", map[string]string{ "Content-Type": "application/json", "Authorization": "Bearer " + viper.GetString("openai.anonymous"), }, map[string]interface{}{ @@ -41,13 +38,13 @@ func GetAnonymousResponse(message string) (string, error) { } func GetAnonymousResponseWithCache(c context.Context, message string) (string, error) { - res, err := Cache.Get(c, fmt.Sprintf(":chatgpt:%s", message)).Result() + res, err := connection.Cache.Get(c, fmt.Sprintf(":chatgpt:%s", message)).Result() if err != nil || len(res) == 0 { res, err := GetAnonymousResponse(message) if err != nil { return "There was something wrong...", err } - Cache.Set(c, fmt.Sprintf(":chatgpt:%s", message), res, time.Hour*6) + connection.Cache.Set(c, fmt.Sprintf(":chatgpt:%s", message), res, time.Hour*6) return res, nil } return res, nil diff --git a/api/types.go b/api/types.go new file mode 100644 index 0000000..91d8650 --- /dev/null +++ b/api/types.go @@ -0,0 +1,6 @@ +package api + +type ChatGPTMessage struct { + Role string `json:"role"` + Content string `json:"content"` +} diff --git a/allauth.go b/auth/call.go similarity index 78% rename from allauth.go rename to auth/call.go index 51aaf30..dce11cf 100644 --- a/allauth.go +++ b/auth/call.go @@ -1,6 +1,7 @@ -package main +package auth import ( + "chat/utils" "encoding/json" "github.com/spf13/viper" ) @@ -12,12 +13,12 @@ type ValidateUserResponse struct { } func Validate(token string) *ValidateUserResponse { - res, err := Post("https://api.deeptrain.net/app/validate", map[string]string{ + res, err := utils.Post("https://api.deeptrain.net/app/validate", map[string]string{ "Content-Type": "application/json", }, map[string]interface{}{ "password": viper.GetString("auth.access"), "token": token, - "hash": Sha2Encrypt(token + viper.GetString("auth.salt")), + "hash": utils.Sha2Encrypt(token + viper.GetString("auth.salt")), }) if err != nil || res == nil || res.(map[string]interface{})["status"] == false { diff --git a/cache.go b/connection/cache.go similarity index 97% rename from cache.go rename to connection/cache.go index 5cdc3f9..54a644d 100644 --- a/cache.go +++ b/connection/cache.go @@ -1,4 +1,4 @@ -package main +package connection import ( "context" diff --git a/main.go b/main.go index 4a5eed1..99f6748 100644 --- a/main.go +++ b/main.go @@ -1,6 +1,8 @@ package main import ( + "chat/api" + "chat/connection" "github.com/gin-gonic/gin" "github.com/spf13/viper" ) @@ -10,11 +12,11 @@ func main() { if err := viper.ReadInConfig(); err != nil { panic(err) } - ConnectRedis() + connection.ConnectRedis() app := gin.Default() { - app.POST("/api/anonymous", AnonymousAPI) + app.POST("/api/anonymous", api.AnonymousAPI) } if err := app.Run(":" + viper.GetString("server.port")); err != nil { panic(err) diff --git a/middleware/cors.go b/middleware/cors.go new file mode 100644 index 0000000..792acfb --- /dev/null +++ b/middleware/cors.go @@ -0,0 +1,34 @@ +package middleware + +import ( + "chat/utils" + "github.com/gin-gonic/gin" + "net/http" +) + +var allowedOrigins = []string{ + "https://fystart.cn", + "https://www.fystart.cn", + "https://deeptrain.net", + "https://www.deeptrain.net", + "http://localhost", +} + +func CORSMiddleware() gin.HandlerFunc { + return func(c *gin.Context) { + origin := c.Request.Header.Get("Origin") + if utils.Contains(origin, allowedOrigins) { + c.Writer.Header().Set("Access-Control-Allow-Origin", origin) + c.Writer.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS") + c.Writer.Header().Set("Access-Control-Allow-Headers", "Origin, Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization") + + if c.Request.Method == "OPTIONS" { + c.Writer.Header().Set("Access-Control-Max-Age", "3600") + c.AbortWithStatus(http.StatusOK) + return + } + } + + c.Next() + } +} diff --git a/utils.go b/utils.go deleted file mode 100644 index 88a516f..0000000 --- a/utils.go +++ /dev/null @@ -1,155 +0,0 @@ -package main - -import ( - "bytes" - "crypto/aes" - "crypto/cipher" - crand "crypto/rand" - "crypto/sha256" - "encoding/hex" - "encoding/json" - "io" - "math/rand" - "net/http" - "net/url" - "strconv" - "time" -) - -func Sha2Encrypt(raw string) string { - hash := sha256.Sum256([]byte(raw)) - return hex.EncodeToString(hash[:]) -} - -func AES256Encrypt(key string, data string) (string, error) { - text := []byte(data) - block, err := aes.NewCipher([]byte(key)) - if err != nil { - return "", err - } - - iv := make([]byte, aes.BlockSize) - if _, err := io.ReadFull(crand.Reader, iv); err != nil { - return "", err - } - - encryptor := cipher.NewCFBEncrypter(block, iv) - - ciphertext := make([]byte, len(text)) - encryptor.XORKeyStream(ciphertext, text) - return hex.EncodeToString(ciphertext), nil -} - -func AES256Decrypt(key string, data string) (string, error) { - ciphertext, err := hex.DecodeString(data) - if err != nil { - return "", err - } - - block, err := aes.NewCipher([]byte(key)) - if err != nil { - return "", err - } - - iv := ciphertext[:aes.BlockSize] - ciphertext = ciphertext[aes.BlockSize:] - - decryptor := cipher.NewCFBDecrypter(block, iv) - plaintext := make([]byte, len(ciphertext)) - decryptor.XORKeyStream(plaintext, ciphertext) - - return string(plaintext), nil -} - -func Http(uri string, method string, ptr interface{}, headers map[string]string, body io.Reader) (err error) { - req, err := http.NewRequest(method, uri, body) - if err != nil { - return err - } - for key, value := range headers { - req.Header.Set(key, value) - } - - client := &http.Client{} - resp, err := client.Do(req) - if err != nil { - return err - } - - defer resp.Body.Close() - - if err = json.NewDecoder(resp.Body).Decode(ptr); err != nil { - return err - } - return nil -} - -func Get(uri string, headers map[string]string) (data interface{}, err error) { - err = Http(uri, http.MethodGet, &data, headers, nil) - return data, err -} - -func Post(uri string, headers map[string]string, body interface{}) (data interface{}, err error) { - var form io.Reader - if buffer, err := json.Marshal(body); err == nil { - form = bytes.NewBuffer(buffer) - } - err = Http(uri, http.MethodPost, &data, headers, form) - return data, err -} - -func PostForm(uri string, body map[string]interface{}) (data map[string]interface{}, err error) { - client := &http.Client{} - form := make(url.Values) - for key, value := range body { - form[key] = []string{value.(string)} - } - res, err := client.PostForm(uri, form) - if err != nil { - return nil, err - } - content, err := io.ReadAll(res.Body) - if err != nil { - return nil, err - } - - if err = json.Unmarshal(content, &data); err != nil { - return nil, err - } - - return data, nil -} - -func GenerateCode(length int) string { - var code string - for i := 0; i < length; i++ { - code += strconv.Itoa(rand.Intn(10)) - } - return code -} - -func GenerateChar(length int) string { - const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" - result := make([]byte, length) - for i := 0; i < length; i++ { - result[i] = charset[rand.Intn(len(charset))] - } - return string(result) -} - -func ConvertTime(t []uint8) *time.Time { - val, err := time.Parse("2006-01-02 15:04:05", string(t)) - if err != nil { - return nil - } - return &val -} - -func Contains[T comparable](value T, slice []T) bool { - for _, item := range slice { - if item == value { - return true - } - } - return false -} diff --git a/utils/base.go b/utils/base.go new file mode 100644 index 0000000..f970787 --- /dev/null +++ b/utils/base.go @@ -0,0 +1,10 @@ +package utils + +func Contains[T comparable](value T, slice []T) bool { + for _, item := range slice { + if item == value { + return true + } + } + return false +} diff --git a/utils/char.go b/utils/char.go new file mode 100644 index 0000000..fae5274 --- /dev/null +++ b/utils/char.go @@ -0,0 +1,32 @@ +package utils + +import ( + "math/rand" + "strconv" + "time" +) + +func GenerateCode(length int) string { + var code string + for i := 0; i < length; i++ { + code += strconv.Itoa(rand.Intn(10)) + } + return code +} + +func GenerateChar(length int) string { + const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" + result := make([]byte, length) + for i := 0; i < length; i++ { + result[i] = charset[rand.Intn(len(charset))] + } + return string(result) +} + +func ConvertTime(t []uint8) *time.Time { + val, err := time.Parse("2006-01-02 15:04:05", string(t)) + if err != nil { + return nil + } + return &val +} diff --git a/utils/encrypt.go b/utils/encrypt.go new file mode 100644 index 0000000..52b1693 --- /dev/null +++ b/utils/encrypt.go @@ -0,0 +1,55 @@ +package utils + +import ( + "crypto/aes" + "crypto/cipher" + crand "crypto/rand" + "crypto/sha256" + "encoding/hex" + "io" +) + +func Sha2Encrypt(raw string) string { + hash := sha256.Sum256([]byte(raw)) + return hex.EncodeToString(hash[:]) +} + +func AES256Encrypt(key string, data string) (string, error) { + text := []byte(data) + block, err := aes.NewCipher([]byte(key)) + if err != nil { + return "", err + } + + iv := make([]byte, aes.BlockSize) + if _, err := io.ReadFull(crand.Reader, iv); err != nil { + return "", err + } + + encryptor := cipher.NewCFBEncrypter(block, iv) + + ciphertext := make([]byte, len(text)) + encryptor.XORKeyStream(ciphertext, text) + return hex.EncodeToString(ciphertext), nil +} + +func AES256Decrypt(key string, data string) (string, error) { + ciphertext, err := hex.DecodeString(data) + if err != nil { + return "", err + } + + block, err := aes.NewCipher([]byte(key)) + if err != nil { + return "", err + } + + iv := ciphertext[:aes.BlockSize] + ciphertext = ciphertext[aes.BlockSize:] + + decryptor := cipher.NewCFBDecrypter(block, iv) + plaintext := make([]byte, len(ciphertext)) + decryptor.XORKeyStream(plaintext, ciphertext) + + return string(plaintext), nil +} diff --git a/utils/net.go b/utils/net.go new file mode 100644 index 0000000..3cd9ae5 --- /dev/null +++ b/utils/net.go @@ -0,0 +1,68 @@ +package utils + +import ( + "bytes" + "encoding/json" + "io" + "net/http" + "net/url" +) + +func Http(uri string, method string, ptr interface{}, headers map[string]string, body io.Reader) (err error) { + req, err := http.NewRequest(method, uri, body) + if err != nil { + return err + } + for key, value := range headers { + req.Header.Set(key, value) + } + + client := &http.Client{} + resp, err := client.Do(req) + if err != nil { + return err + } + + defer resp.Body.Close() + + if err = json.NewDecoder(resp.Body).Decode(ptr); err != nil { + return err + } + return nil +} + +func Get(uri string, headers map[string]string) (data interface{}, err error) { + err = Http(uri, http.MethodGet, &data, headers, nil) + return data, err +} + +func Post(uri string, headers map[string]string, body interface{}) (data interface{}, err error) { + var form io.Reader + if buffer, err := json.Marshal(body); err == nil { + form = bytes.NewBuffer(buffer) + } + err = Http(uri, http.MethodPost, &data, headers, form) + return data, err +} + +func PostForm(uri string, body map[string]interface{}) (data map[string]interface{}, err error) { + client := &http.Client{} + form := make(url.Values) + for key, value := range body { + form[key] = []string{value.(string)} + } + res, err := client.PostForm(uri, form) + if err != nil { + return nil, err + } + content, err := io.ReadAll(res.Body) + if err != nil { + return nil, err + } + + if err = json.Unmarshal(content, &data); err != nil { + return nil, err + } + + return data, nil +}