diff --git a/README.md b/README.md index 9ea2c39..8764b45 100644 --- a/README.md +++ b/README.md @@ -52,6 +52,8 @@ openai: anonymous_endpoint: https://api.openai.com/v1 user: sk-xxxxxx|sk-xxxxxx|sk-xxxxxx user_endpoint: https://api.openai.com/v1 + image: sk-xxxxxx|sk-xxxxxx|sk-xxxxxx + image_endpoint: https://api.openai.com/v1 mysql: host: localhost diff --git a/api/chat.go b/api/chat.go index 9cb3a2f..db54d76 100644 --- a/api/chat.go +++ b/api/chat.go @@ -7,9 +7,12 @@ import ( "chat/types" "chat/utils" "database/sql" + "fmt" "github.com/gin-gonic/gin" + "github.com/go-redis/redis/v8" "github.com/gorilla/websocket" "net/http" + "strings" ) type WebsocketAuthForm struct { @@ -21,6 +24,63 @@ func SendSegmentMessage(conn *websocket.Conn, message types.ChatGPTSegmentRespon _ = conn.WriteMessage(websocket.TextMessage, []byte(utils.ToJson(message))) } +func TextChat(conn *websocket.Conn, instance *conversation.Conversation) string { + keyword, segment := ChatWithWeb(conversation.CopyMessage(instance.GetMessageSegment(12)), true) + SendSegmentMessage(conn, types.ChatGPTSegmentResponse{Keyword: keyword, End: false}) + + msg := "" + StreamRequest("gpt-3.5-turbo-16k-0613", segment, 2000, func(resp string) { + msg += resp + SendSegmentMessage(conn, types.ChatGPTSegmentResponse{ + Message: resp, + End: false, + }) + }) + if msg == "" { + msg = "There was something wrong... Please try again later." + SendSegmentMessage(conn, types.ChatGPTSegmentResponse{ + Message: msg, + End: false, + }) + } + SendSegmentMessage(conn, types.ChatGPTSegmentResponse{End: true}) + + return msg +} + +func ImageChat(conn *websocket.Conn, instance *conversation.Conversation, user *auth.User, db *sql.DB, cache *redis.Client) string { + // format: /image a cat + data := strings.TrimSpace(instance.GetLatestMessage()[6:]) + if len(data) == 0 { + SendSegmentMessage(conn, types.ChatGPTSegmentResponse{ + Message: "Please provide description for the image.", + End: true, + }) + return "Please provide description for the image." + } + + SendSegmentMessage(conn, types.ChatGPTSegmentResponse{ + Message: "Generating image...\n", + End: false, + }) + url, err := GetImageWithUserLimit(user, data, db, cache) + if err != nil { + SendSegmentMessage(conn, types.ChatGPTSegmentResponse{ + Message: err.Error(), + End: true, + }) + return err.Error() + } + + markdown := fmt.Sprintln("![image](", url, ")") + SendSegmentMessage(conn, types.ChatGPTSegmentResponse{ + Message: markdown, + Keyword: "image", + End: true, + }) + return markdown +} + func ChatAPI(c *gin.Context) { // websocket connection upgrader := websocket.Upgrader{ @@ -85,27 +145,14 @@ func ChatAPI(c *gin.Context) { return } if instance.HandleMessage(db, message) { - keyword, segment := ChatWithWeb(conversation.CopyMessage(instance.GetMessageSegment(12)), true) - SendSegmentMessage(conn, types.ChatGPTSegmentResponse{Keyword: keyword, End: false}) - - msg := "" - StreamRequest("gpt-3.5-turbo-16k-0613", segment, 2000, func(resp string) { - msg += resp - SendSegmentMessage(conn, types.ChatGPTSegmentResponse{ - Message: resp, - End: false, - }) - }) - if msg == "" { - msg = "There was something wrong... Please try again later." - SendSegmentMessage(conn, types.ChatGPTSegmentResponse{ - Message: msg, - End: false, - }) + var msg string + if strings.HasPrefix(instance.GetLatestMessage(), "/image") { + cache := c.MustGet("cache").(*redis.Client) + msg = ImageChat(conn, instance, user, db, cache) + } else { + msg = TextChat(conn, instance) } - instance.SaveResponse(db, msg) - SendSegmentMessage(conn, types.ChatGPTSegmentResponse{End: true}) } } } diff --git a/api/image.go b/api/image.go new file mode 100644 index 0000000..d95faf5 --- /dev/null +++ b/api/image.go @@ -0,0 +1,64 @@ +package api + +import ( + "chat/auth" + "chat/types" + "chat/utils" + "context" + "database/sql" + "fmt" + "github.com/go-redis/redis/v8" + "github.com/spf13/viper" + "time" +) + +func GetImage(prompt string) (string, error) { + res, err := utils.Post(viper.GetString("openai.image_endpoint")+"/images/generations", map[string]string{ + "Content-Type": "application/json", + "Authorization": "Bearer " + GetRandomKey(viper.GetString("openai.image")), + }, types.ChatGPTImageRequest{ + Prompt: prompt, + Size: "512x512", + N: 1, + }) + if err != nil || res == nil { + return "", err + } + + if err, ok := res.(map[string]interface{})["error"]; ok { + return "", fmt.Errorf(err.(map[string]interface{})["message"].(string)) + } + data := res.(map[string]interface{})["data"].([]interface{})[0].(map[string]interface{})["url"] + return data.(string), nil +} + +func GetImageWithCache(ctx context.Context, prompt string, cache *redis.Client) (string, error) { + res, err := cache.Get(ctx, fmt.Sprintf(":image:%s", prompt)).Result() + if err != nil || len(res) == 0 || res == "" { + res, err := GetImage(prompt) + if err != nil { + return "", err + } + + cache.Set(ctx, fmt.Sprintf(":image:%s", prompt), res, time.Hour*6) + return res, nil + } + + return res, nil +} + +func GetImageWithUserLimit(user *auth.User, prompt string, db *sql.DB, cache *redis.Client) (string, error) { + // 3 images one day per user (count by cache) + res, err := cache.Get(context.Background(), fmt.Sprintf(":imagelimit:%d", user.GetID(db))).Result() + if err != nil || len(res) == 0 || res == "" { + cache.Set(context.Background(), fmt.Sprintf(":imagelimit:%d", user.GetID(db)), "1", time.Hour*24) + return GetImageWithCache(context.Background(), prompt, cache) + } + + if res == "3" { + return "", fmt.Errorf("you have reached your limit of 3 images per day") + } else { + cache.Set(context.Background(), fmt.Sprintf(":imagelimit:%d", user.GetID(db)), fmt.Sprintf("%d", utils.ToInt(res)+1), time.Hour*24) + return GetImageWithCache(context.Background(), prompt, cache) + } +} diff --git a/app/src/assets/script/conversation.ts b/app/src/assets/script/conversation.ts index 54da76f..c82e607 100644 --- a/app/src/assets/script/conversation.ts +++ b/app/src/assets/script/conversation.ts @@ -212,7 +212,7 @@ export class Conversation { public dynamicTypingEffect(index: number, content: Ref, keyword: Ref, end: Ref): void { let cursor = 0; const interval = setInterval(() => { - keyword.value && (this.messages[index].keyword = keyword.value); + if (keyword.value && keyword.value !== "image") this.messages[index].keyword = keyword.value; if (end.value && cursor >= content.value.length) { this.messages[index].content = content.value; this.state.value = false; @@ -221,6 +221,9 @@ export class Conversation { } if (cursor >= content.value.length) return; cursor++; + if (keyword.value === "image") { + cursor = content.value.length; + } this.messages[index].content = content.value.substring(0, cursor); this.refresh && this.refresh(); }, 20); diff --git a/config.example.yaml b/config.example.yaml index 295c95b..4574610 100644 --- a/config.example.yaml +++ b/config.example.yaml @@ -11,6 +11,8 @@ openai: anonymous_endpoint: https://api.openai.com/v1 user: sk-xxxxxx|sk-xxxxxx|sk-xxxxxx user_endpoint: https://api.openai.com/v1 + image: sk-xxxxxx|sk-xxxxxx|sk-xxxxxx + image_endpoint: https://api.openai.com/v1 mysql: host: localhost diff --git a/conversation/conversation.go b/conversation/conversation.go index 8278cce..a01f684 100644 --- a/conversation/conversation.go +++ b/conversation/conversation.go @@ -132,6 +132,10 @@ func (c *Conversation) HandleMessage(db *sql.DB, data []byte) bool { return true } +func (c *Conversation) GetLatestMessage() string { + return c.Message[len(c.Message)-1].Content +} + func (c *Conversation) SaveResponse(db *sql.DB, message string) { c.AddMessageFromAssistant(message) c.SaveConversation(db) diff --git a/types/types.go b/types/types.go index 0a53e98..0ad714c 100644 --- a/types/types.go +++ b/types/types.go @@ -12,6 +12,12 @@ type ChatGPTRequest struct { Stream bool `json:"stream"` } +type ChatGPTImageRequest struct { + Prompt string `json:"prompt"` + Size string `json:"size"` + N int `json:"n"` +} + type ChatGPTStreamResponse struct { ID string `json:"id"` Object string `json:"object"` diff --git a/utils/char.go b/utils/char.go index 1df77df..b176be1 100644 --- a/utils/char.go +++ b/utils/char.go @@ -36,3 +36,11 @@ func Unmarshal[T interface{}](data []byte) (form T, err error) { err = json.Unmarshal(data, &form) return form, err } + +func ToInt(value string) int { + if res, err := strconv.Atoi(value); err == nil { + return res + } else { + return 0 + } +}