mirror of
https://github.com/coaidev/coai.git
synced 2025-05-22 14:30:14 +09:00
add image generation feature
This commit is contained in:
parent
68b49f6a09
commit
0d7085e7f4
@ -52,6 +52,8 @@ openai:
|
|||||||
anonymous_endpoint: https://api.openai.com/v1
|
anonymous_endpoint: https://api.openai.com/v1
|
||||||
user: sk-xxxxxx|sk-xxxxxx|sk-xxxxxx
|
user: sk-xxxxxx|sk-xxxxxx|sk-xxxxxx
|
||||||
user_endpoint: https://api.openai.com/v1
|
user_endpoint: https://api.openai.com/v1
|
||||||
|
image: sk-xxxxxx|sk-xxxxxx|sk-xxxxxx
|
||||||
|
image_endpoint: https://api.openai.com/v1
|
||||||
|
|
||||||
mysql:
|
mysql:
|
||||||
host: localhost
|
host: localhost
|
||||||
|
85
api/chat.go
85
api/chat.go
@ -7,9 +7,12 @@ import (
|
|||||||
"chat/types"
|
"chat/types"
|
||||||
"chat/utils"
|
"chat/utils"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/go-redis/redis/v8"
|
||||||
"github.com/gorilla/websocket"
|
"github.com/gorilla/websocket"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
type WebsocketAuthForm struct {
|
type WebsocketAuthForm struct {
|
||||||
@ -21,6 +24,63 @@ func SendSegmentMessage(conn *websocket.Conn, message types.ChatGPTSegmentRespon
|
|||||||
_ = conn.WriteMessage(websocket.TextMessage, []byte(utils.ToJson(message)))
|
_ = conn.WriteMessage(websocket.TextMessage, []byte(utils.ToJson(message)))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TextChat(conn *websocket.Conn, instance *conversation.Conversation) string {
|
||||||
|
keyword, segment := ChatWithWeb(conversation.CopyMessage(instance.GetMessageSegment(12)), true)
|
||||||
|
SendSegmentMessage(conn, types.ChatGPTSegmentResponse{Keyword: keyword, End: false})
|
||||||
|
|
||||||
|
msg := ""
|
||||||
|
StreamRequest("gpt-3.5-turbo-16k-0613", segment, 2000, func(resp string) {
|
||||||
|
msg += resp
|
||||||
|
SendSegmentMessage(conn, types.ChatGPTSegmentResponse{
|
||||||
|
Message: resp,
|
||||||
|
End: false,
|
||||||
|
})
|
||||||
|
})
|
||||||
|
if msg == "" {
|
||||||
|
msg = "There was something wrong... Please try again later."
|
||||||
|
SendSegmentMessage(conn, types.ChatGPTSegmentResponse{
|
||||||
|
Message: msg,
|
||||||
|
End: false,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
SendSegmentMessage(conn, types.ChatGPTSegmentResponse{End: true})
|
||||||
|
|
||||||
|
return msg
|
||||||
|
}
|
||||||
|
|
||||||
|
func ImageChat(conn *websocket.Conn, instance *conversation.Conversation, user *auth.User, db *sql.DB, cache *redis.Client) string {
|
||||||
|
// format: /image a cat
|
||||||
|
data := strings.TrimSpace(instance.GetLatestMessage()[6:])
|
||||||
|
if len(data) == 0 {
|
||||||
|
SendSegmentMessage(conn, types.ChatGPTSegmentResponse{
|
||||||
|
Message: "Please provide description for the image.",
|
||||||
|
End: true,
|
||||||
|
})
|
||||||
|
return "Please provide description for the image."
|
||||||
|
}
|
||||||
|
|
||||||
|
SendSegmentMessage(conn, types.ChatGPTSegmentResponse{
|
||||||
|
Message: "Generating image...\n",
|
||||||
|
End: false,
|
||||||
|
})
|
||||||
|
url, err := GetImageWithUserLimit(user, data, db, cache)
|
||||||
|
if err != nil {
|
||||||
|
SendSegmentMessage(conn, types.ChatGPTSegmentResponse{
|
||||||
|
Message: err.Error(),
|
||||||
|
End: true,
|
||||||
|
})
|
||||||
|
return err.Error()
|
||||||
|
}
|
||||||
|
|
||||||
|
markdown := fmt.Sprintln("")
|
||||||
|
SendSegmentMessage(conn, types.ChatGPTSegmentResponse{
|
||||||
|
Message: markdown,
|
||||||
|
Keyword: "image",
|
||||||
|
End: true,
|
||||||
|
})
|
||||||
|
return markdown
|
||||||
|
}
|
||||||
|
|
||||||
func ChatAPI(c *gin.Context) {
|
func ChatAPI(c *gin.Context) {
|
||||||
// websocket connection
|
// websocket connection
|
||||||
upgrader := websocket.Upgrader{
|
upgrader := websocket.Upgrader{
|
||||||
@ -85,27 +145,14 @@ func ChatAPI(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
if instance.HandleMessage(db, message) {
|
if instance.HandleMessage(db, message) {
|
||||||
keyword, segment := ChatWithWeb(conversation.CopyMessage(instance.GetMessageSegment(12)), true)
|
var msg string
|
||||||
SendSegmentMessage(conn, types.ChatGPTSegmentResponse{Keyword: keyword, End: false})
|
if strings.HasPrefix(instance.GetLatestMessage(), "/image") {
|
||||||
|
cache := c.MustGet("cache").(*redis.Client)
|
||||||
msg := ""
|
msg = ImageChat(conn, instance, user, db, cache)
|
||||||
StreamRequest("gpt-3.5-turbo-16k-0613", segment, 2000, func(resp string) {
|
} else {
|
||||||
msg += resp
|
msg = TextChat(conn, instance)
|
||||||
SendSegmentMessage(conn, types.ChatGPTSegmentResponse{
|
|
||||||
Message: resp,
|
|
||||||
End: false,
|
|
||||||
})
|
|
||||||
})
|
|
||||||
if msg == "" {
|
|
||||||
msg = "There was something wrong... Please try again later."
|
|
||||||
SendSegmentMessage(conn, types.ChatGPTSegmentResponse{
|
|
||||||
Message: msg,
|
|
||||||
End: false,
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
instance.SaveResponse(db, msg)
|
instance.SaveResponse(db, msg)
|
||||||
SendSegmentMessage(conn, types.ChatGPTSegmentResponse{End: true})
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
64
api/image.go
Normal file
64
api/image.go
Normal file
@ -0,0 +1,64 @@
|
|||||||
|
package api
|
||||||
|
|
||||||
|
import (
|
||||||
|
"chat/auth"
|
||||||
|
"chat/types"
|
||||||
|
"chat/utils"
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
"github.com/go-redis/redis/v8"
|
||||||
|
"github.com/spf13/viper"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func GetImage(prompt string) (string, error) {
|
||||||
|
res, err := utils.Post(viper.GetString("openai.image_endpoint")+"/images/generations", map[string]string{
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
"Authorization": "Bearer " + GetRandomKey(viper.GetString("openai.image")),
|
||||||
|
}, types.ChatGPTImageRequest{
|
||||||
|
Prompt: prompt,
|
||||||
|
Size: "512x512",
|
||||||
|
N: 1,
|
||||||
|
})
|
||||||
|
if err != nil || res == nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err, ok := res.(map[string]interface{})["error"]; ok {
|
||||||
|
return "", fmt.Errorf(err.(map[string]interface{})["message"].(string))
|
||||||
|
}
|
||||||
|
data := res.(map[string]interface{})["data"].([]interface{})[0].(map[string]interface{})["url"]
|
||||||
|
return data.(string), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetImageWithCache(ctx context.Context, prompt string, cache *redis.Client) (string, error) {
|
||||||
|
res, err := cache.Get(ctx, fmt.Sprintf(":image:%s", prompt)).Result()
|
||||||
|
if err != nil || len(res) == 0 || res == "" {
|
||||||
|
res, err := GetImage(prompt)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
cache.Set(ctx, fmt.Sprintf(":image:%s", prompt), res, time.Hour*6)
|
||||||
|
return res, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return res, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetImageWithUserLimit(user *auth.User, prompt string, db *sql.DB, cache *redis.Client) (string, error) {
|
||||||
|
// 3 images one day per user (count by cache)
|
||||||
|
res, err := cache.Get(context.Background(), fmt.Sprintf(":imagelimit:%d", user.GetID(db))).Result()
|
||||||
|
if err != nil || len(res) == 0 || res == "" {
|
||||||
|
cache.Set(context.Background(), fmt.Sprintf(":imagelimit:%d", user.GetID(db)), "1", time.Hour*24)
|
||||||
|
return GetImageWithCache(context.Background(), prompt, cache)
|
||||||
|
}
|
||||||
|
|
||||||
|
if res == "3" {
|
||||||
|
return "", fmt.Errorf("you have reached your limit of 3 images per day")
|
||||||
|
} else {
|
||||||
|
cache.Set(context.Background(), fmt.Sprintf(":imagelimit:%d", user.GetID(db)), fmt.Sprintf("%d", utils.ToInt(res)+1), time.Hour*24)
|
||||||
|
return GetImageWithCache(context.Background(), prompt, cache)
|
||||||
|
}
|
||||||
|
}
|
@ -212,7 +212,7 @@ export class Conversation {
|
|||||||
public dynamicTypingEffect(index: number, content: Ref<string>, keyword: Ref<string>, end: Ref<boolean>): void {
|
public dynamicTypingEffect(index: number, content: Ref<string>, keyword: Ref<string>, end: Ref<boolean>): void {
|
||||||
let cursor = 0;
|
let cursor = 0;
|
||||||
const interval = setInterval(() => {
|
const interval = setInterval(() => {
|
||||||
keyword.value && (this.messages[index].keyword = keyword.value);
|
if (keyword.value && keyword.value !== "image") this.messages[index].keyword = keyword.value;
|
||||||
if (end.value && cursor >= content.value.length) {
|
if (end.value && cursor >= content.value.length) {
|
||||||
this.messages[index].content = content.value;
|
this.messages[index].content = content.value;
|
||||||
this.state.value = false;
|
this.state.value = false;
|
||||||
@ -221,6 +221,9 @@ export class Conversation {
|
|||||||
}
|
}
|
||||||
if (cursor >= content.value.length) return;
|
if (cursor >= content.value.length) return;
|
||||||
cursor++;
|
cursor++;
|
||||||
|
if (keyword.value === "image") {
|
||||||
|
cursor = content.value.length;
|
||||||
|
}
|
||||||
this.messages[index].content = content.value.substring(0, cursor);
|
this.messages[index].content = content.value.substring(0, cursor);
|
||||||
this.refresh && this.refresh();
|
this.refresh && this.refresh();
|
||||||
}, 20);
|
}, 20);
|
||||||
|
@ -11,6 +11,8 @@ openai:
|
|||||||
anonymous_endpoint: https://api.openai.com/v1
|
anonymous_endpoint: https://api.openai.com/v1
|
||||||
user: sk-xxxxxx|sk-xxxxxx|sk-xxxxxx
|
user: sk-xxxxxx|sk-xxxxxx|sk-xxxxxx
|
||||||
user_endpoint: https://api.openai.com/v1
|
user_endpoint: https://api.openai.com/v1
|
||||||
|
image: sk-xxxxxx|sk-xxxxxx|sk-xxxxxx
|
||||||
|
image_endpoint: https://api.openai.com/v1
|
||||||
|
|
||||||
mysql:
|
mysql:
|
||||||
host: localhost
|
host: localhost
|
||||||
|
@ -132,6 +132,10 @@ func (c *Conversation) HandleMessage(db *sql.DB, data []byte) bool {
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *Conversation) GetLatestMessage() string {
|
||||||
|
return c.Message[len(c.Message)-1].Content
|
||||||
|
}
|
||||||
|
|
||||||
func (c *Conversation) SaveResponse(db *sql.DB, message string) {
|
func (c *Conversation) SaveResponse(db *sql.DB, message string) {
|
||||||
c.AddMessageFromAssistant(message)
|
c.AddMessageFromAssistant(message)
|
||||||
c.SaveConversation(db)
|
c.SaveConversation(db)
|
||||||
|
@ -12,6 +12,12 @@ type ChatGPTRequest struct {
|
|||||||
Stream bool `json:"stream"`
|
Stream bool `json:"stream"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type ChatGPTImageRequest struct {
|
||||||
|
Prompt string `json:"prompt"`
|
||||||
|
Size string `json:"size"`
|
||||||
|
N int `json:"n"`
|
||||||
|
}
|
||||||
|
|
||||||
type ChatGPTStreamResponse struct {
|
type ChatGPTStreamResponse struct {
|
||||||
ID string `json:"id"`
|
ID string `json:"id"`
|
||||||
Object string `json:"object"`
|
Object string `json:"object"`
|
||||||
|
@ -36,3 +36,11 @@ func Unmarshal[T interface{}](data []byte) (form T, err error) {
|
|||||||
err = json.Unmarshal(data, &form)
|
err = json.Unmarshal(data, &form)
|
||||||
return form, err
|
return form, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func ToInt(value string) int {
|
||||||
|
if res, err := strconv.Atoi(value); err == nil {
|
||||||
|
return res
|
||||||
|
} else {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user