feat: Google Imagen image generation

Sh1n3zZ 2025-03-25 01:56:28 +08:00
parent 8f0f9a0fda
commit 1856dd0312
4 changed files with 299 additions and 85 deletions


@@ -4,16 +4,22 @@ import (
 	adaptercommon "chat/adapter/common"
 	"chat/globals"
 	"chat/utils"
+	"errors"
 	"fmt"
+	"strings"
 )
 
 var geminiMaxImages = 16
 
-func (c *ChatInstance) GetChatEndpoint(model string) string {
+func (c *ChatInstance) GetChatEndpoint(model string, stream bool) string {
 	if model == globals.ChatBison001 {
 		return fmt.Sprintf("%s/v1beta2/models/%s:generateMessage?key=%s", c.Endpoint, model, c.ApiKey)
 	}
 
+	if stream {
+		return fmt.Sprintf("%s/v1beta/models/%s:streamGenerateContent?alt=sse&key=%s", c.Endpoint, model, c.ApiKey)
+	}
+
 	return fmt.Sprintf("%s/v1beta/models/%s:generateContent?key=%s", c.Endpoint, model, c.ApiKey)
 }
@@ -88,7 +94,7 @@ func (c *ChatInstance) GetGeminiChatResponse(data interface{}) (string, error) {
 }
 
 func (c *ChatInstance) CreateChatRequest(props *adaptercommon.ChatProps) (string, error) {
-	uri := c.GetChatEndpoint(props.Model)
+	uri := c.GetChatEndpoint(props.Model, false)
 
 	if props.Model == globals.ChatBison001 {
 		data, err := utils.Post(uri, map[string]string{
@@ -112,9 +118,19 @@ func (c *ChatInstance) CreateChatRequest(props *adaptercommon.ChatProps) (string
 	return c.GetGeminiChatResponse(data)
 }
 
-// CreateStreamChatRequest is the mock stream request for palm2
-// tips: palm2 does not support stream request
+// CreateStreamChatRequest is the stream request for palm2
 func (c *ChatInstance) CreateStreamChatRequest(props *adaptercommon.ChatProps, callback globals.Hook) error {
+	// Handle imagen models
+	if globals.IsGoogleImagenModel(props.Model) {
+		response, err := c.CreateImage(props)
+		if err != nil {
+			return err
+		}
+
+		return callback(&globals.Chunk{Content: response})
+	}
+
+	// Handle chat models
+	if props.Model == globals.ChatBison001 {
 		response, err := c.CreateChatRequest(props)
 		if err != nil {
 			return err
@@ -127,3 +143,64 @@ func (c *ChatInstance) CreateStreamChatRequest(props *adaptercommon.ChatProps, c
-	}
-	return nil
-}
+		}
+
+		return nil
+	}
+
+	ticks := 0
+	scanErr := utils.EventScanner(&utils.EventScannerProps{
+		Method: "POST",
+		Uri:    c.GetChatEndpoint(props.Model, true),
+		Headers: map[string]string{
+			"Content-Type": "application/json",
+		},
+		Body: c.GetGeminiChatBody(props),
+		Callback: func(data string) error {
+			ticks += 1
+
+			if form := utils.UnmarshalForm[GeminiStreamResponse](data); form != nil {
+				if len(form.Candidates) != 0 && len(form.Candidates[0].Content.Parts) != 0 {
+					return callback(&globals.Chunk{
+						Content: form.Candidates[0].Content.Parts[0].Text,
+					})
+				}
+
+				return nil
+			}
+
+			if form := utils.UnmarshalForm[GeminiChatErrorResponse](data); form != nil {
+				return fmt.Errorf("gemini error: %s (code: %d, status: %s)", form.Error.Message, form.Error.Code, form.Error.Status)
+			}
+
+			return nil
+		},
+	}, props.Proxy)
+
+	if scanErr != nil {
+		if scanErr.Error != nil && strings.Contains(scanErr.Error.Error(), "status code: 404") {
+			// downgrade to non-stream request
+			response, err := c.CreateChatRequest(props)
+			if err != nil {
+				return err
+			}
+
+			return callback(&globals.Chunk{Content: response})
+		}
+
+		if scanErr.Body != "" {
+			if form := utils.UnmarshalForm[GeminiChatErrorResponse](scanErr.Body); form != nil {
+				return fmt.Errorf("gemini error: %s (code: %d, status: %s)", form.Error.Message, form.Error.Code, form.Error.Status)
+			}
+
+			return fmt.Errorf("gemini error: %s", scanErr.Body)
+		}
+
+		return fmt.Errorf("gemini error: %v", scanErr.Error)
+	}
+
+	if ticks == 0 {
+		return errors.New("no response")
+	}
+
+	return nil
+}
+
+func (c *ChatInstance) GetLatestPrompt(props *adaptercommon.ChatProps) string {
+	if len(props.Message) == 0 {
+		return ""
+	}
+
+	return props.Message[len(props.Message)-1].Content
+}
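Net effect of this file: the stream entry point now short-circuits Imagen models to CreateImage (delivering one markdown chunk), keeps the existing non-stream path for chat-bison-001, and otherwise streams from the streamGenerateContent SSE endpoint with a non-stream fallback on 404. A rough caller-side sketch of the new Imagen branch follows; it is illustrative only and not part of the commit: "instance" stands for an already-configured *ChatInstance, and the Role/User fields on globals.Message are assumptions about the surrounding codebase.

	// Hypothetical usage, not part of the commit.
	props := &adaptercommon.ChatProps{
		Model: globals.GoogleImagen002, // routed via globals.IsGoogleImagenModel
		Message: []globals.Message{
			{Role: globals.User, Content: "a watercolor painting of a red fox"},
		},
	}
	err := instance.CreateStreamChatRequest(props, func(chunk *globals.Chunk) error {
		fmt.Print(chunk.Content) // for Imagen models: a single chunk of image markdown
		return nil
	})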

adapter/palm2/image.go (new file, 82 lines)

@@ -0,0 +1,82 @@
+package palm2
+
+import (
+	adaptercommon "chat/adapter/common"
+	"chat/globals"
+	"chat/utils"
+	"fmt"
+	"strings"
+)
+
+type ImageProps struct {
+	Model  string
+	Prompt string
+	Proxy  globals.ProxyConfig
+}
+
+func (c *ChatInstance) GetImageEndpoint(model string) string {
+	return fmt.Sprintf("%s/v1beta/models/%s:predict?key=%s", c.Endpoint, model, c.ApiKey)
+}
+
+// CreateImageRequest will create a gemini imagen from prompt, return base64 of image and error
+func (c *ChatInstance) CreateImageRequest(props ImageProps) (string, error) {
+	res, err := utils.Post(
+		c.GetImageEndpoint(props.Model),
+		map[string]string{
+			"Content-Type": "application/json",
+		},
+		ImageRequest{
+			Instances: []ImageInstance{
+				{
+					Prompt: props.Prompt,
+				},
+			},
+			Parameters: ImageParameters{
+				SampleCount:      1,
+				AspectRatio:      "1:1",
+				PersonGeneration: "allow_adult",
+			},
+		},
+		props.Proxy,
+	)
+	if err != nil || res == nil {
+		return "", fmt.Errorf("gemini error: %s", err.Error())
+	}
+
+	data := utils.MapToStruct[ImageResponse](res)
+	if data == nil {
+		return "", fmt.Errorf("gemini error: cannot parse response")
+	}
+
+	if len(data.Predictions) == 0 {
+		return "", fmt.Errorf("gemini error: no image generated")
+	}
+
+	return data.Predictions[0].BytesBase64Encoded, nil
+}
+
+// CreateImage will create a gemini imagen from prompt, return markdown of image
+func (c *ChatInstance) CreateImage(props *adaptercommon.ChatProps) (string, error) {
+	if !globals.IsGoogleImagenModel(props.Model) {
+		return "", nil
+	}
+
+	base64Data, err := c.CreateImageRequest(ImageProps{
+		Model:  props.Model,
+		Prompt: c.GetLatestPrompt(props),
+		Proxy:  props.Proxy,
+	})
+	if err != nil {
+		if strings.Contains(err.Error(), "safety") {
+			return err.Error(), nil
+		}
+
+		return "", err
+	}
+
+	// Convert base64 to data URL format
+	dataUrl := fmt.Sprintf("data:image/png;base64,%s", base64Data)
+	url := utils.StoreImage(dataUrl)
+
+	return utils.GetImageMarkdown(url), nil
+}


@@ -70,3 +70,40 @@ type GeminiChatErrorResponse struct {
 		Status string `json:"status"`
 	} `json:"error"`
 }
+
+type GeminiStreamResponse struct {
+	Candidates []struct {
+		Content struct {
+			Parts []struct {
+				Text string `json:"text"`
+			} `json:"parts"`
+			Role string `json:"role"`
+		} `json:"content"`
+	} `json:"candidates"`
+}
+
+// ImageRequest is the native http request body for imagen
+type ImageRequest struct {
+	Instances  []ImageInstance `json:"instances"`
+	Parameters ImageParameters `json:"parameters"`
+}
+
+type ImageInstance struct {
+	Prompt string `json:"prompt"`
+}
+
+type ImageParameters struct {
+	SampleCount      int    `json:"sampleCount,omitempty"`
+	AspectRatio      string `json:"aspectRatio,omitempty"`
+	PersonGeneration string `json:"personGeneration,omitempty"`
+}
+
+// ImageResponse is the native http response body for imagen
+type ImageResponse struct {
+	Predictions []ImagePrediction `json:"predictions"`
+}
+
+type ImagePrediction struct {
+	MimeType           string `json:"mimeType"`
+	BytesBase64Encoded string `json:"bytesBase64Encoded"`
+}
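With the defaults hard-coded in CreateImageRequest (SampleCount 1, AspectRatio "1:1", PersonGeneration "allow_adult"), ImageRequest marshals to the JSON body below for the predict endpoint. This is an illustrative sketch, not part of the commit; it assumes encoding/json and fmt are imported and uses a made-up prompt.

	// Hypothetical sketch: the wire format produced by the new request types.
	body, _ := json.Marshal(ImageRequest{
		Instances:  []ImageInstance{{Prompt: "a watercolor painting of a red fox"}},
		Parameters: ImageParameters{SampleCount: 1, AspectRatio: "1:1", PersonGeneration: "allow_adult"},
	})
	fmt.Println(string(body))
	// {"instances":[{"prompt":"a watercolor painting of a red fox"}],"parameters":{"sampleCount":1,"aspectRatio":"1:1","personGeneration":"allow_adult"}}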


@@ -109,6 +109,15 @@ const (
 	GeminiProVision = "gemini-pro-vision"
 	Gemini15ProLatest = "gemini-1.5-pro-latest"
 	Gemini15FlashLatest = "gemini-1.5-flash-latest"
+	Gemini20ProExp = "gemini-2.0-pro-exp-02-05"
+	Gemini20Flash = "gemini-2.0-flash"
+	Gemini20FlashExp = "gemini-2.0-flash-exp"
+	Gemini20Flash001 = "gemini-2.0-flash-001"
+	Gemini20FlashThinkingExp = "gemini-2.0-flash-thinking-exp-01-21"
+	Gemini20FlashLitePreview = "gemini-2.0-flash-lite-preview-02-05"
+	Gemini20FlashThinkingExp1219 = "gemini-2.0-flash-thinking-exp-1219"
+	GeminiExp1206 = "gemini-exp-1206"
+	GoogleImagen002 = "imagen-3.0-generate-002"
 	BingCreative = "bing-creative"
 	BingBalanced = "bing-balanced"
 	BingPrecise = "bing-precise"
@@ -141,6 +150,10 @@ var OpenAIDalleModels = []string{
 	Dalle, Dalle2, Dalle3,
 }
 
+var GoogleImagenModels = []string{
+	GoogleImagen002,
+}
+
 var VisionModels = []string{
 	GPT4VisionPreview, GPT41106VisionPreview, GPT4Turbo, GPT4Turbo20240409, GPT4O, GPT4O20240513, // openai
 	GeminiProVision, Gemini15ProLatest, Gemini15FlashLatest, // gemini
@@ -166,6 +179,11 @@ func IsOpenAIDalleModel(model string) bool {
 	return in(model, OpenAIDalleModels) && !strings.Contains(model, "gpt-4-dalle")
 }
 
+func IsGoogleImagenModel(model string) bool {
+	// using image generation api if model is in imagen models
+	return in(model, GoogleImagenModels)
+}
+
 func IsVisionModel(model string) bool {
 	return in(model, VisionModels) && !in(model, VisionSkipModels)
 }