feat: Google Imagen image generation
Some checks failed
Build Test / release (18.x) (push) Has been cancelled
Docker Image CI / build (push) Has been cancelled

This commit is contained in:
Sh1n3zZ 2025-03-25 01:56:28 +08:00
parent 8f0f9a0fda
commit 1856dd0312
No known key found for this signature in database
GPG Key ID: 696702CF723B0452
4 changed files with 299 additions and 85 deletions

View File

@ -4,16 +4,22 @@ import (
adaptercommon "chat/adapter/common"
"chat/globals"
"chat/utils"
"errors"
"fmt"
"strings"
)
// geminiMaxImages caps the number of images handled per request.
// NOTE(review): not referenced anywhere in this listing — confirm its usage elsewhere in the package.
var geminiMaxImages = 16
func (c *ChatInstance) GetChatEndpoint(model string) string {
// GetChatEndpoint resolves the Google API URL for the given model: the legacy
// v1beta2 generateMessage endpoint for chat-bison, the SSE streaming endpoint
// when stream is true, and the plain generateContent endpoint otherwise.
// The instance API key is appended as a query parameter in every case.
func (c *ChatInstance) GetChatEndpoint(model string, stream bool) string {
	switch {
	case model == globals.ChatBison001:
		return fmt.Sprintf("%s/v1beta2/models/%s:generateMessage?key=%s", c.Endpoint, model, c.ApiKey)
	case stream:
		return fmt.Sprintf("%s/v1beta/models/%s:streamGenerateContent?alt=sse&key=%s", c.Endpoint, model, c.ApiKey)
	default:
		return fmt.Sprintf("%s/v1beta/models/%s:generateContent?key=%s", c.Endpoint, model, c.ApiKey)
	}
}
@ -88,7 +94,7 @@ func (c *ChatInstance) GetGeminiChatResponse(data interface{}) (string, error) {
}
func (c *ChatInstance) CreateChatRequest(props *adaptercommon.ChatProps) (string, error) {
uri := c.GetChatEndpoint(props.Model)
uri := c.GetChatEndpoint(props.Model, false)
if props.Model == globals.ChatBison001 {
data, err := utils.Post(uri, map[string]string{
@ -112,18 +118,89 @@ func (c *ChatInstance) CreateChatRequest(props *adaptercommon.ChatProps) (string
return c.GetGeminiChatResponse(data)
}
// CreateStreamChatRequest is the mock stream request for palm2
// tips: palm2 does not support stream request
// CreateStreamChatRequest is the stream request for palm2
func (c *ChatInstance) CreateStreamChatRequest(props *adaptercommon.ChatProps, callback globals.Hook) error {
// NOTE(review): the next six lines duplicate the chat-bison branch below and
// leave an unterminated for/if; they appear to be removed lines from a diff
// render — reconcile against the repository before trusting this listing.
response, err := c.CreateChatRequest(props)
if err != nil {
return err
}
for _, item := range utils.SplitItem(response, " ") {
if err := callback(&globals.Chunk{Content: item}); err != nil {
// Handle imagen models
// imagen has no streaming route here: generate once and emit one chunk
if globals.IsGoogleImagenModel(props.Model) {
response, err := c.CreateImage(props)
if err != nil {
return err
}
return callback(&globals.Chunk{Content: response})
}
// Handle chat models
// chat-bison does not support streaming: fetch the full answer and replay it
// word by word through the callback to mimic a stream
if props.Model == globals.ChatBison001 {
response, err := c.CreateChatRequest(props)
if err != nil {
return err
}
for _, item := range utils.SplitItem(response, " ") {
if err := callback(&globals.Chunk{Content: item}); err != nil {
return err
}
}
return nil
}
// ticks counts received SSE events so an empty stream can be detected below
ticks := 0
scanErr := utils.EventScanner(&utils.EventScannerProps{
Method: "POST",
Uri: c.GetChatEndpoint(props.Model, true),
Headers: map[string]string{
"Content-Type": "application/json",
},
Body: c.GetGeminiChatBody(props),
Callback: func(data string) error {
ticks += 1
// forward the first text part of the first candidate, if present
if form := utils.UnmarshalForm[GeminiStreamResponse](data); form != nil {
if len(form.Candidates) != 0 && len(form.Candidates[0].Content.Parts) != 0 {
return callback(&globals.Chunk{
Content: form.Candidates[0].Content.Parts[0].Text,
})
}
return nil
}
// an event may instead carry a structured error payload
if form := utils.UnmarshalForm[GeminiChatErrorResponse](data); form != nil {
return fmt.Errorf("gemini error: %s (code: %d, status: %s)", form.Error.Message, form.Error.Code, form.Error.Status)
}
return nil
},
}, props.Proxy)
if scanErr != nil {
// a 404 from the streaming route triggers a non-stream fallback
if scanErr.Error != nil && strings.Contains(scanErr.Error.Error(), "status code: 404") {
// downgrade to non-stream request
response, err := c.CreateChatRequest(props)
if err != nil {
return err
}
return callback(&globals.Chunk{Content: response})
}
// prefer the structured error from the response body when it parses
if scanErr.Body != "" {
if form := utils.UnmarshalForm[GeminiChatErrorResponse](scanErr.Body); form != nil {
return fmt.Errorf("gemini error: %s (code: %d, status: %s)", form.Error.Message, form.Error.Code, form.Error.Status)
}
return fmt.Errorf("gemini error: %s", scanErr.Body)
}
return fmt.Errorf("gemini error: %v", scanErr.Error)
}
// a stream that produced no events at all is treated as a failure
if ticks == 0 {
return errors.New("no response")
}
return nil
}
// GetLatestPrompt returns the content of the most recent message in the
// conversation, or the empty string when there are no messages.
func (c *ChatInstance) GetLatestPrompt(props *adaptercommon.ChatProps) string {
	messages := props.Message
	if n := len(messages); n > 0 {
		return messages[n-1].Content
	}
	return ""
}

82
adapter/palm2/image.go Normal file
View File

@ -0,0 +1,82 @@
package palm2
import (
adaptercommon "chat/adapter/common"
"chat/globals"
"chat/utils"
"fmt"
"strings"
)
// ImageProps bundles the parameters for a single Imagen generation call.
type ImageProps struct {
// Model is the imagen model identifier (e.g. imagen-3.0-generate-002).
Model string
// Prompt is the text the image is generated from.
Prompt string
// Proxy is the outbound proxy configuration forwarded to the HTTP client.
Proxy globals.ProxyConfig
}
// GetImageEndpoint builds the Imagen :predict URL for the given model,
// authenticated by the instance API key passed as a query parameter.
func (c *ChatInstance) GetImageEndpoint(model string) string {
	return c.Endpoint + "/v1beta/models/" + model + ":predict?key=" + c.ApiKey
}
// CreateImageRequest generates a gemini imagen from the prompt and returns the
// base64-encoded image data. It returns an error when the HTTP request fails,
// when the response cannot be parsed, or when no prediction is returned.
func (c *ChatInstance) CreateImageRequest(props ImageProps) (string, error) {
	res, err := utils.Post(
		c.GetImageEndpoint(props.Model),
		map[string]string{
			"Content-Type": "application/json",
		},
		ImageRequest{
			Instances: []ImageInstance{
				{
					Prompt: props.Prompt,
				},
			},
			Parameters: ImageParameters{
				SampleCount:      1,
				AspectRatio:      "1:1",
				PersonGeneration: "allow_adult",
			},
		},
		props.Proxy,
	)
	if err != nil {
		return "", fmt.Errorf("gemini error: %s", err.Error())
	}
	// checked separately: the original combined guard `err != nil || res == nil`
	// called err.Error() and panicked on a nil error with a nil response
	if res == nil {
		return "", fmt.Errorf("gemini error: empty response")
	}

	data := utils.MapToStruct[ImageResponse](res)
	if data == nil {
		return "", fmt.Errorf("gemini error: cannot parse response")
	}
	if len(data.Predictions) == 0 {
		return "", fmt.Errorf("gemini error: no image generated")
	}

	return data.Predictions[0].BytesBase64Encoded, nil
}
// CreateImage generates an image for the latest prompt via the Imagen API and
// returns it rendered as markdown. Non-imagen models yield "" with no error.
func (c *ChatInstance) CreateImage(props *adaptercommon.ChatProps) (string, error) {
	if !globals.IsGoogleImagenModel(props.Model) {
		return "", nil
	}

	encoded, err := c.CreateImageRequest(ImageProps{
		Model:  props.Model,
		Prompt: c.GetLatestPrompt(props),
		Proxy:  props.Proxy,
	})
	if err != nil {
		// safety rejections are surfaced as chat content rather than an error
		if strings.Contains(err.Error(), "safety") {
			return err.Error(), nil
		}
		return "", err
	}

	// wrap the base64 payload in a data URL, persist it, and emit markdown
	stored := utils.StoreImage("data:image/png;base64," + encoded)
	return utils.GetImageMarkdown(stored), nil
}

View File

@ -70,3 +70,40 @@ type GeminiChatErrorResponse struct {
Status string `json:"status"`
} `json:"error"`
}
// GeminiStreamResponse is one SSE event of a streamGenerateContent response;
// only the candidates' text parts and role are decoded.
type GeminiStreamResponse struct {
Candidates []struct {
Content struct {
Parts []struct {
Text string `json:"text"`
} `json:"parts"`
Role string `json:"role"`
} `json:"content"`
} `json:"candidates"`
}
// ImageRequest is the native http request body for the imagen :predict endpoint.
type ImageRequest struct {
Instances []ImageInstance `json:"instances"`
Parameters ImageParameters `json:"parameters"`
}
// ImageInstance carries a single generation prompt.
type ImageInstance struct {
Prompt string `json:"prompt"`
}
// ImageParameters holds optional imagen generation settings; zero values are
// omitted from the JSON payload.
type ImageParameters struct {
SampleCount int `json:"sampleCount,omitempty"`
AspectRatio string `json:"aspectRatio,omitempty"`
PersonGeneration string `json:"personGeneration,omitempty"`
}
// ImageResponse is the native http response body for imagen.
type ImageResponse struct {
Predictions []ImagePrediction `json:"predictions"`
}
// ImagePrediction is a single generated image: its MIME type and the
// base64-encoded image bytes.
type ImagePrediction struct {
MimeType string `json:"mimeType"`
BytesBase64Encoded string `json:"bytesBase64Encoded"`
}

View File

@ -61,86 +61,99 @@ func OriginIsOpen(c *gin.Context) bool {
}
// Model name constants for every upstream provider routed by this adapter.
const (
// NOTE(review): this listing interleaves two revisions of the same constant
// block (identifiers such as GPT3Turbo appear twice); as rendered it would not
// compile — this looks like a diff-render artifact, reconcile with the repo.
GPT3Turbo = "gpt-3.5-turbo"
GPT3TurboInstruct = "gpt-3.5-turbo-instruct"
GPT3Turbo0613 = "gpt-3.5-turbo-0613"
GPT3Turbo0301 = "gpt-3.5-turbo-0301"
GPT3Turbo1106 = "gpt-3.5-turbo-1106"
GPT3Turbo0125 = "gpt-3.5-turbo-0125"
GPT3Turbo16k = "gpt-3.5-turbo-16k"
GPT3Turbo16k0613 = "gpt-3.5-turbo-16k-0613"
GPT3Turbo16k0301 = "gpt-3.5-turbo-16k-0301"
GPT4 = "gpt-4"
GPT4All = "gpt-4-all"
GPT4Vision = "gpt-4-v"
GPT4Dalle = "gpt-4-dalle"
GPT40314 = "gpt-4-0314"
GPT40613 = "gpt-4-0613"
GPT41106Preview = "gpt-4-1106-preview"
GPT40125Preview = "gpt-4-0125-preview"
GPT4TurboPreview = "gpt-4-turbo-preview"
GPT4VisionPreview = "gpt-4-vision-preview"
GPT4Turbo = "gpt-4-turbo"
GPT4Turbo20240409 = "gpt-4-turbo-2024-04-09"
GPT41106VisionPreview = "gpt-4-1106-vision-preview"
GPT432k = "gpt-4-32k"
GPT432k0314 = "gpt-4-32k-0314"
GPT432k0613 = "gpt-4-32k-0613"
GPT4O = "gpt-4o"
GPT4O20240513 = "gpt-4o-2024-05-13"
Dalle = "dalle"
Dalle2 = "dall-e-2"
Dalle3 = "dall-e-3"
Claude1 = "claude-1"
Claude1100k = "claude-1.3"
Claude2 = "claude-1-100k"
Claude2100k = "claude-2"
Claude2200k = "claude-2.1"
Claude3 = "claude-3"
ClaudeSlack = "claude-slack"
SparkDeskLite = "spark-desk-lite"
SparkDeskPro = "spark-desk-pro"
SparkDeskPro128K = "spark-desk-pro-128k"
SparkDeskMax = "spark-desk-max"
SparkDeskMax32K = "spark-desk-max-32k"
SparkDeskV4Ultra = "spark-desk-4.0-ultra"
ChatBison001 = "chat-bison-001"
GeminiPro = "gemini-pro"
GeminiProVision = "gemini-pro-vision"
Gemini15ProLatest = "gemini-1.5-pro-latest"
Gemini15FlashLatest = "gemini-1.5-flash-latest"
BingCreative = "bing-creative"
BingBalanced = "bing-balanced"
BingPrecise = "bing-precise"
ZhiPuChatGLM4 = "glm-4"
ZhiPuChatGLM4Vision = "glm-4v"
ZhiPuChatGLM3Turbo = "glm-3-turbo"
ZhiPuChatGLMTurbo = "zhipu-chatglm-turbo"
ZhiPuChatGLMPro = "zhipu-chatglm-pro"
ZhiPuChatGLMStd = "zhipu-chatglm-std"
ZhiPuChatGLMLite = "zhipu-chatglm-lite"
QwenTurbo = "qwen-turbo"
QwenPlus = "qwen-plus"
QwenTurboNet = "qwen-turbo-net"
QwenPlusNet = "qwen-plus-net"
Midjourney = "midjourney"
MidjourneyFast = "midjourney-fast"
MidjourneyTurbo = "midjourney-turbo"
Hunyuan = "hunyuan"
GPT360V9 = "360-gpt-v9"
Baichuan53B = "baichuan-53b"
SkylarkLite = "skylark-lite-public"
SkylarkPlus = "skylark-plus-public"
SkylarkPro = "skylark-pro-public"
SkylarkChat = "skylark-chat"
DeepseekV3 = "deepseek-chat"
DeepseekR1 = "deepseek-reasoner"
// second (updated) revision of the same constants; it adds the Gemini 2.0
// family and the Imagen image-generation model below
GPT3Turbo = "gpt-3.5-turbo"
GPT3TurboInstruct = "gpt-3.5-turbo-instruct"
GPT3Turbo0613 = "gpt-3.5-turbo-0613"
GPT3Turbo0301 = "gpt-3.5-turbo-0301"
GPT3Turbo1106 = "gpt-3.5-turbo-1106"
GPT3Turbo0125 = "gpt-3.5-turbo-0125"
GPT3Turbo16k = "gpt-3.5-turbo-16k"
GPT3Turbo16k0613 = "gpt-3.5-turbo-16k-0613"
GPT3Turbo16k0301 = "gpt-3.5-turbo-16k-0301"
GPT4 = "gpt-4"
GPT4All = "gpt-4-all"
GPT4Vision = "gpt-4-v"
GPT4Dalle = "gpt-4-dalle"
GPT40314 = "gpt-4-0314"
GPT40613 = "gpt-4-0613"
GPT41106Preview = "gpt-4-1106-preview"
GPT40125Preview = "gpt-4-0125-preview"
GPT4TurboPreview = "gpt-4-turbo-preview"
GPT4VisionPreview = "gpt-4-vision-preview"
GPT4Turbo = "gpt-4-turbo"
GPT4Turbo20240409 = "gpt-4-turbo-2024-04-09"
GPT41106VisionPreview = "gpt-4-1106-vision-preview"
GPT432k = "gpt-4-32k"
GPT432k0314 = "gpt-4-32k-0314"
GPT432k0613 = "gpt-4-32k-0613"
GPT4O = "gpt-4o"
GPT4O20240513 = "gpt-4o-2024-05-13"
Dalle = "dalle"
Dalle2 = "dall-e-2"
Dalle3 = "dall-e-3"
Claude1 = "claude-1"
Claude1100k = "claude-1.3"
Claude2 = "claude-1-100k"
Claude2100k = "claude-2"
Claude2200k = "claude-2.1"
Claude3 = "claude-3"
ClaudeSlack = "claude-slack"
SparkDeskLite = "spark-desk-lite"
SparkDeskPro = "spark-desk-pro"
SparkDeskPro128K = "spark-desk-pro-128k"
SparkDeskMax = "spark-desk-max"
SparkDeskMax32K = "spark-desk-max-32k"
SparkDeskV4Ultra = "spark-desk-4.0-ultra"
ChatBison001 = "chat-bison-001"
GeminiPro = "gemini-pro"
GeminiProVision = "gemini-pro-vision"
Gemini15ProLatest = "gemini-1.5-pro-latest"
Gemini15FlashLatest = "gemini-1.5-flash-latest"
Gemini20ProExp = "gemini-2.0-pro-exp-02-05"
Gemini20Flash = "gemini-2.0-flash"
Gemini20FlashExp = "gemini-2.0-flash-exp"
Gemini20Flash001 = "gemini-2.0-flash-001"
Gemini20FlashThinkingExp = "gemini-2.0-flash-thinking-exp-01-21"
Gemini20FlashLitePreview = "gemini-2.0-flash-lite-preview-02-05"
Gemini20FlashThinkingExp1219 = "gemini-2.0-flash-thinking-exp-1219"
GeminiExp1206 = "gemini-exp-1206"
// GoogleImagen002 is routed to the dedicated image generation API
GoogleImagen002 = "imagen-3.0-generate-002"
BingCreative = "bing-creative"
BingBalanced = "bing-balanced"
BingPrecise = "bing-precise"
ZhiPuChatGLM4 = "glm-4"
ZhiPuChatGLM4Vision = "glm-4v"
ZhiPuChatGLM3Turbo = "glm-3-turbo"
ZhiPuChatGLMTurbo = "zhipu-chatglm-turbo"
ZhiPuChatGLMPro = "zhipu-chatglm-pro"
ZhiPuChatGLMStd = "zhipu-chatglm-std"
ZhiPuChatGLMLite = "zhipu-chatglm-lite"
QwenTurbo = "qwen-turbo"
QwenPlus = "qwen-plus"
QwenTurboNet = "qwen-turbo-net"
QwenPlusNet = "qwen-plus-net"
Midjourney = "midjourney"
MidjourneyFast = "midjourney-fast"
MidjourneyTurbo = "midjourney-turbo"
Hunyuan = "hunyuan"
GPT360V9 = "360-gpt-v9"
Baichuan53B = "baichuan-53b"
SkylarkLite = "skylark-lite-public"
SkylarkPlus = "skylark-plus-public"
SkylarkPro = "skylark-pro-public"
SkylarkChat = "skylark-chat"
DeepseekV3 = "deepseek-chat"
DeepseekR1 = "deepseek-reasoner"
)
// OpenAIDalleModels lists the model names routed to the OpenAI DALL·E image API.
var OpenAIDalleModels = []string{
Dalle, Dalle2, Dalle3,
}
// GoogleImagenModels lists the model names routed to the Google Imagen image API.
var GoogleImagenModels = []string{
GoogleImagen002,
}
var VisionModels = []string{
GPT4VisionPreview, GPT41106VisionPreview, GPT4Turbo, GPT4Turbo20240409, GPT4O, GPT4O20240513, // openai
GeminiProVision, Gemini15ProLatest, Gemini15FlashLatest, // gemini
@ -166,6 +179,11 @@ func IsOpenAIDalleModel(model string) bool {
return in(model, OpenAIDalleModels) && !strings.Contains(model, "gpt-4-dalle")
}
// IsGoogleImagenModel reports whether the model should be served by the
// dedicated Google Imagen image generation API.
func IsGoogleImagenModel(model string) bool {
// using image generation api if model is in imagen models
return in(model, GoogleImagenModels)
}
// IsVisionModel reports whether the model accepts image input: it must appear
// in the vision model list and not be explicitly skipped.
func IsVisionModel(model string) bool {
	isVision := in(model, VisionModels)
	isSkipped := in(model, VisionSkipModels)
	return isVision && !isSkipped
}