mirror of
https://github.com/coaidev/coai.git
synced 2025-05-19 04:50:14 +09:00
feat: Google Imagen image generation
This commit is contained in:
parent
8f0f9a0fda
commit
1856dd0312
@ -4,16 +4,22 @@ import (
|
||||
adaptercommon "chat/adapter/common"
|
||||
"chat/globals"
|
||||
"chat/utils"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// geminiMaxImages bounds how many images a single Gemini request may carry.
// NOTE(review): the limit's consumer is outside this view — confirm usage.
var geminiMaxImages = 16
||||
func (c *ChatInstance) GetChatEndpoint(model string) string {
|
||||
func (c *ChatInstance) GetChatEndpoint(model string, stream bool) string {
|
||||
if model == globals.ChatBison001 {
|
||||
return fmt.Sprintf("%s/v1beta2/models/%s:generateMessage?key=%s", c.Endpoint, model, c.ApiKey)
|
||||
}
|
||||
|
||||
if stream {
|
||||
return fmt.Sprintf("%s/v1beta/models/%s:streamGenerateContent?alt=sse&key=%s", c.Endpoint, model, c.ApiKey)
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%s/v1beta/models/%s:generateContent?key=%s", c.Endpoint, model, c.ApiKey)
|
||||
}
|
||||
|
||||
@ -88,7 +94,7 @@ func (c *ChatInstance) GetGeminiChatResponse(data interface{}) (string, error) {
|
||||
}
|
||||
|
||||
func (c *ChatInstance) CreateChatRequest(props *adaptercommon.ChatProps) (string, error) {
|
||||
uri := c.GetChatEndpoint(props.Model)
|
||||
uri := c.GetChatEndpoint(props.Model, false)
|
||||
|
||||
if props.Model == globals.ChatBison001 {
|
||||
data, err := utils.Post(uri, map[string]string{
|
||||
@ -112,18 +118,89 @@ func (c *ChatInstance) CreateChatRequest(props *adaptercommon.ChatProps) (string
|
||||
return c.GetGeminiChatResponse(data)
|
||||
}
|
||||
|
||||
// CreateStreamChatRequest is the mock stream request for palm2
|
||||
// tips: palm2 does not support stream request
|
||||
// CreateStreamChatRequest is the stream request for palm2
|
||||
func (c *ChatInstance) CreateStreamChatRequest(props *adaptercommon.ChatProps, callback globals.Hook) error {
|
||||
response, err := c.CreateChatRequest(props)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, item := range utils.SplitItem(response, " ") {
|
||||
if err := callback(&globals.Chunk{Content: item}); err != nil {
|
||||
// Handle imagen models
|
||||
if globals.IsGoogleImagenModel(props.Model) {
|
||||
response, err := c.CreateImage(props)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return callback(&globals.Chunk{Content: response})
|
||||
}
|
||||
|
||||
// Handle chat models
|
||||
if props.Model == globals.ChatBison001 {
|
||||
response, err := c.CreateChatRequest(props)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, item := range utils.SplitItem(response, " ") {
|
||||
if err := callback(&globals.Chunk{Content: item}); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
ticks := 0
|
||||
scanErr := utils.EventScanner(&utils.EventScannerProps{
|
||||
Method: "POST",
|
||||
Uri: c.GetChatEndpoint(props.Model, true),
|
||||
Headers: map[string]string{
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
Body: c.GetGeminiChatBody(props),
|
||||
Callback: func(data string) error {
|
||||
ticks += 1
|
||||
|
||||
if form := utils.UnmarshalForm[GeminiStreamResponse](data); form != nil {
|
||||
if len(form.Candidates) != 0 && len(form.Candidates[0].Content.Parts) != 0 {
|
||||
return callback(&globals.Chunk{
|
||||
Content: form.Candidates[0].Content.Parts[0].Text,
|
||||
})
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
if form := utils.UnmarshalForm[GeminiChatErrorResponse](data); form != nil {
|
||||
return fmt.Errorf("gemini error: %s (code: %d, status: %s)", form.Error.Message, form.Error.Code, form.Error.Status)
|
||||
}
|
||||
|
||||
return nil
|
||||
},
|
||||
}, props.Proxy)
|
||||
|
||||
if scanErr != nil {
|
||||
if scanErr.Error != nil && strings.Contains(scanErr.Error.Error(), "status code: 404") {
|
||||
// downgrade to non-stream request
|
||||
response, err := c.CreateChatRequest(props)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return callback(&globals.Chunk{Content: response})
|
||||
}
|
||||
|
||||
if scanErr.Body != "" {
|
||||
if form := utils.UnmarshalForm[GeminiChatErrorResponse](scanErr.Body); form != nil {
|
||||
return fmt.Errorf("gemini error: %s (code: %d, status: %s)", form.Error.Message, form.Error.Code, form.Error.Status)
|
||||
}
|
||||
return fmt.Errorf("gemini error: %s", scanErr.Body)
|
||||
}
|
||||
return fmt.Errorf("gemini error: %v", scanErr.Error)
|
||||
}
|
||||
|
||||
if ticks == 0 {
|
||||
return errors.New("no response")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *ChatInstance) GetLatestPrompt(props *adaptercommon.ChatProps) string {
|
||||
if len(props.Message) == 0 {
|
||||
return ""
|
||||
}
|
||||
return props.Message[len(props.Message)-1].Content
|
||||
}
|
||||
|
82
adapter/palm2/image.go
Normal file
82
adapter/palm2/image.go
Normal file
@ -0,0 +1,82 @@
|
||||
package palm2
|
||||
|
||||
import (
|
||||
adaptercommon "chat/adapter/common"
|
||||
"chat/globals"
|
||||
"chat/utils"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type ImageProps struct {
|
||||
Model string
|
||||
Prompt string
|
||||
Proxy globals.ProxyConfig
|
||||
}
|
||||
|
||||
func (c *ChatInstance) GetImageEndpoint(model string) string {
|
||||
return fmt.Sprintf("%s/v1beta/models/%s:predict?key=%s", c.Endpoint, model, c.ApiKey)
|
||||
}
|
||||
|
||||
// CreateImageRequest will create a gemini imagen from prompt, return base64 of image and error
|
||||
func (c *ChatInstance) CreateImageRequest(props ImageProps) (string, error) {
|
||||
res, err := utils.Post(
|
||||
c.GetImageEndpoint(props.Model),
|
||||
map[string]string{
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
ImageRequest{
|
||||
Instances: []ImageInstance{
|
||||
{
|
||||
Prompt: props.Prompt,
|
||||
},
|
||||
},
|
||||
Parameters: ImageParameters{
|
||||
SampleCount: 1,
|
||||
AspectRatio: "1:1",
|
||||
PersonGeneration: "allow_adult",
|
||||
},
|
||||
},
|
||||
props.Proxy,
|
||||
)
|
||||
|
||||
if err != nil || res == nil {
|
||||
return "", fmt.Errorf("gemini error: %s", err.Error())
|
||||
}
|
||||
|
||||
data := utils.MapToStruct[ImageResponse](res)
|
||||
if data == nil {
|
||||
return "", fmt.Errorf("gemini error: cannot parse response")
|
||||
}
|
||||
|
||||
if len(data.Predictions) == 0 {
|
||||
return "", fmt.Errorf("gemini error: no image generated")
|
||||
}
|
||||
|
||||
return data.Predictions[0].BytesBase64Encoded, nil
|
||||
}
|
||||
|
||||
// CreateImage will create a gemini imagen from prompt, return markdown of image
|
||||
func (c *ChatInstance) CreateImage(props *adaptercommon.ChatProps) (string, error) {
|
||||
if !globals.IsGoogleImagenModel(props.Model) {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
base64Data, err := c.CreateImageRequest(ImageProps{
|
||||
Model: props.Model,
|
||||
Prompt: c.GetLatestPrompt(props),
|
||||
Proxy: props.Proxy,
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "safety") {
|
||||
return err.Error(), nil
|
||||
}
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Convert base64 to data URL format
|
||||
dataUrl := fmt.Sprintf("data:image/png;base64,%s", base64Data)
|
||||
url := utils.StoreImage(dataUrl)
|
||||
return utils.GetImageMarkdown(url), nil
|
||||
}
|
@ -70,3 +70,40 @@ type GeminiChatErrorResponse struct {
|
||||
Status string `json:"status"`
|
||||
} `json:"error"`
|
||||
}
|
||||
|
||||
type GeminiStreamResponse struct {
|
||||
Candidates []struct {
|
||||
Content struct {
|
||||
Parts []struct {
|
||||
Text string `json:"text"`
|
||||
} `json:"parts"`
|
||||
Role string `json:"role"`
|
||||
} `json:"content"`
|
||||
} `json:"candidates"`
|
||||
}
|
||||
|
||||
// ImageRequest is the native http request body for imagen
|
||||
type ImageRequest struct {
|
||||
Instances []ImageInstance `json:"instances"`
|
||||
Parameters ImageParameters `json:"parameters"`
|
||||
}
|
||||
|
||||
type ImageInstance struct {
|
||||
Prompt string `json:"prompt"`
|
||||
}
|
||||
|
||||
type ImageParameters struct {
|
||||
SampleCount int `json:"sampleCount,omitempty"`
|
||||
AspectRatio string `json:"aspectRatio,omitempty"`
|
||||
PersonGeneration string `json:"personGeneration,omitempty"`
|
||||
}
|
||||
|
||||
// ImageResponse is the native http response body for imagen
|
||||
type ImageResponse struct {
|
||||
Predictions []ImagePrediction `json:"predictions"`
|
||||
}
|
||||
|
||||
type ImagePrediction struct {
|
||||
MimeType string `json:"mimeType"`
|
||||
BytesBase64Encoded string `json:"bytesBase64Encoded"`
|
||||
}
|
||||
|
@ -61,86 +61,99 @@ func OriginIsOpen(c *gin.Context) bool {
|
||||
}
|
||||
|
||||
// Model identifiers understood by the router. Grouped by provider:
// OpenAI, Anthropic, iFlytek Spark, Google, Bing, ZhiPu, Qwen, Midjourney,
// Tencent, 360, Baichuan, ByteDance Skylark, Deepseek.
const (
	GPT3Turbo             = "gpt-3.5-turbo"
	GPT3TurboInstruct     = "gpt-3.5-turbo-instruct"
	GPT3Turbo0613         = "gpt-3.5-turbo-0613"
	GPT3Turbo0301         = "gpt-3.5-turbo-0301"
	GPT3Turbo1106         = "gpt-3.5-turbo-1106"
	GPT3Turbo0125         = "gpt-3.5-turbo-0125"
	GPT3Turbo16k          = "gpt-3.5-turbo-16k"
	GPT3Turbo16k0613      = "gpt-3.5-turbo-16k-0613"
	GPT3Turbo16k0301      = "gpt-3.5-turbo-16k-0301"
	GPT4                  = "gpt-4"
	GPT4All               = "gpt-4-all"
	GPT4Vision            = "gpt-4-v"
	GPT4Dalle             = "gpt-4-dalle"
	GPT40314              = "gpt-4-0314"
	GPT40613              = "gpt-4-0613"
	GPT41106Preview       = "gpt-4-1106-preview"
	GPT40125Preview       = "gpt-4-0125-preview"
	GPT4TurboPreview      = "gpt-4-turbo-preview"
	GPT4VisionPreview     = "gpt-4-vision-preview"
	GPT4Turbo             = "gpt-4-turbo"
	GPT4Turbo20240409     = "gpt-4-turbo-2024-04-09"
	GPT41106VisionPreview = "gpt-4-1106-vision-preview"
	GPT432k               = "gpt-4-32k"
	GPT432k0314           = "gpt-4-32k-0314"
	GPT432k0613           = "gpt-4-32k-0613"
	GPT4O                 = "gpt-4o"
	GPT4O20240513         = "gpt-4o-2024-05-13"
	Dalle                 = "dalle"
	Dalle2                = "dall-e-2"
	Dalle3                = "dall-e-3"
	Claude1               = "claude-1"
	// NOTE(review): the claude names and values below appear shifted relative
	// to each other (e.g. Claude2 = "claude-1-100k"); preserved as-is because
	// the string values are the wire-level identifiers callers depend on —
	// confirm before renaming.
	Claude1100k = "claude-1.3"
	Claude2     = "claude-1-100k"
	Claude2100k = "claude-2"
	Claude2200k = "claude-2.1"
	Claude3     = "claude-3"
	ClaudeSlack = "claude-slack"

	SparkDeskLite    = "spark-desk-lite"
	SparkDeskPro     = "spark-desk-pro"
	SparkDeskPro128K = "spark-desk-pro-128k"
	SparkDeskMax     = "spark-desk-max"
	SparkDeskMax32K  = "spark-desk-max-32k"
	SparkDeskV4Ultra = "spark-desk-4.0-ultra"

	ChatBison001                 = "chat-bison-001"
	GeminiPro                    = "gemini-pro"
	GeminiProVision              = "gemini-pro-vision"
	Gemini15ProLatest            = "gemini-1.5-pro-latest"
	Gemini15FlashLatest          = "gemini-1.5-flash-latest"
	Gemini20ProExp               = "gemini-2.0-pro-exp-02-05"
	Gemini20Flash                = "gemini-2.0-flash"
	Gemini20FlashExp             = "gemini-2.0-flash-exp"
	Gemini20Flash001             = "gemini-2.0-flash-001"
	Gemini20FlashThinkingExp     = "gemini-2.0-flash-thinking-exp-01-21"
	Gemini20FlashLitePreview     = "gemini-2.0-flash-lite-preview-02-05"
	Gemini20FlashThinkingExp1219 = "gemini-2.0-flash-thinking-exp-1219"
	GeminiExp1206                = "gemini-exp-1206"
	GoogleImagen002              = "imagen-3.0-generate-002"

	BingCreative = "bing-creative"
	BingBalanced = "bing-balanced"
	BingPrecise  = "bing-precise"

	ZhiPuChatGLM4       = "glm-4"
	ZhiPuChatGLM4Vision = "glm-4v"
	ZhiPuChatGLM3Turbo  = "glm-3-turbo"
	ZhiPuChatGLMTurbo   = "zhipu-chatglm-turbo"
	ZhiPuChatGLMPro     = "zhipu-chatglm-pro"
	ZhiPuChatGLMStd     = "zhipu-chatglm-std"
	ZhiPuChatGLMLite    = "zhipu-chatglm-lite"

	QwenTurbo    = "qwen-turbo"
	QwenPlus     = "qwen-plus"
	QwenTurboNet = "qwen-turbo-net"
	QwenPlusNet  = "qwen-plus-net"

	Midjourney      = "midjourney"
	MidjourneyFast  = "midjourney-fast"
	MidjourneyTurbo = "midjourney-turbo"

	Hunyuan     = "hunyuan"
	GPT360V9    = "360-gpt-v9"
	Baichuan53B = "baichuan-53b"

	SkylarkLite = "skylark-lite-public"
	SkylarkPlus = "skylark-plus-public"
	SkylarkPro  = "skylark-pro-public"
	SkylarkChat = "skylark-chat"

	DeepseekV3 = "deepseek-chat"
	DeepseekR1 = "deepseek-reasoner"
)

// OpenAIDalleModels lists the models routed to the OpenAI DALL·E image API.
var OpenAIDalleModels = []string{
	Dalle, Dalle2, Dalle3,
}

// GoogleImagenModels lists the models routed to the Google Imagen predict API.
var GoogleImagenModels = []string{
	GoogleImagen002,
}
|
||||
|
||||
var VisionModels = []string{
|
||||
GPT4VisionPreview, GPT41106VisionPreview, GPT4Turbo, GPT4Turbo20240409, GPT4O, GPT4O20240513, // openai
|
||||
GeminiProVision, Gemini15ProLatest, Gemini15FlashLatest, // gemini
|
||||
@ -166,6 +179,11 @@ func IsOpenAIDalleModel(model string) bool {
|
||||
return in(model, OpenAIDalleModels) && !strings.Contains(model, "gpt-4-dalle")
|
||||
}
|
||||
|
||||
func IsGoogleImagenModel(model string) bool {
|
||||
// using image generation api if model is in imagen models
|
||||
return in(model, GoogleImagenModels)
|
||||
}
|
||||
|
||||
func IsVisionModel(model string) bool {
|
||||
return in(model, VisionModels) && !in(model, VisionSkipModels)
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user