mirror of
https://github.com/coaidev/coai.git
synced 2025-05-19 21:10:18 +09:00
feat: Google Imagen image generation
This commit is contained in:
parent
8f0f9a0fda
commit
1856dd0312
@ -4,16 +4,22 @@ import (
|
|||||||
adaptercommon "chat/adapter/common"
|
adaptercommon "chat/adapter/common"
|
||||||
"chat/globals"
|
"chat/globals"
|
||||||
"chat/utils"
|
"chat/utils"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
var geminiMaxImages = 16
|
var geminiMaxImages = 16
|
||||||
|
|
||||||
func (c *ChatInstance) GetChatEndpoint(model string) string {
|
func (c *ChatInstance) GetChatEndpoint(model string, stream bool) string {
|
||||||
if model == globals.ChatBison001 {
|
if model == globals.ChatBison001 {
|
||||||
return fmt.Sprintf("%s/v1beta2/models/%s:generateMessage?key=%s", c.Endpoint, model, c.ApiKey)
|
return fmt.Sprintf("%s/v1beta2/models/%s:generateMessage?key=%s", c.Endpoint, model, c.ApiKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if stream {
|
||||||
|
return fmt.Sprintf("%s/v1beta/models/%s:streamGenerateContent?alt=sse&key=%s", c.Endpoint, model, c.ApiKey)
|
||||||
|
}
|
||||||
|
|
||||||
return fmt.Sprintf("%s/v1beta/models/%s:generateContent?key=%s", c.Endpoint, model, c.ApiKey)
|
return fmt.Sprintf("%s/v1beta/models/%s:generateContent?key=%s", c.Endpoint, model, c.ApiKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -88,7 +94,7 @@ func (c *ChatInstance) GetGeminiChatResponse(data interface{}) (string, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *ChatInstance) CreateChatRequest(props *adaptercommon.ChatProps) (string, error) {
|
func (c *ChatInstance) CreateChatRequest(props *adaptercommon.ChatProps) (string, error) {
|
||||||
uri := c.GetChatEndpoint(props.Model)
|
uri := c.GetChatEndpoint(props.Model, false)
|
||||||
|
|
||||||
if props.Model == globals.ChatBison001 {
|
if props.Model == globals.ChatBison001 {
|
||||||
data, err := utils.Post(uri, map[string]string{
|
data, err := utils.Post(uri, map[string]string{
|
||||||
@ -112,9 +118,19 @@ func (c *ChatInstance) CreateChatRequest(props *adaptercommon.ChatProps) (string
|
|||||||
return c.GetGeminiChatResponse(data)
|
return c.GetGeminiChatResponse(data)
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreateStreamChatRequest is the mock stream request for palm2
|
// CreateStreamChatRequest is the stream request for palm2
|
||||||
// tips: palm2 does not support stream request
|
|
||||||
func (c *ChatInstance) CreateStreamChatRequest(props *adaptercommon.ChatProps, callback globals.Hook) error {
|
func (c *ChatInstance) CreateStreamChatRequest(props *adaptercommon.ChatProps, callback globals.Hook) error {
|
||||||
|
// Handle imagen models
|
||||||
|
if globals.IsGoogleImagenModel(props.Model) {
|
||||||
|
response, err := c.CreateImage(props)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return callback(&globals.Chunk{Content: response})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle chat models
|
||||||
|
if props.Model == globals.ChatBison001 {
|
||||||
response, err := c.CreateChatRequest(props)
|
response, err := c.CreateChatRequest(props)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@ -127,3 +143,64 @@ func (c *ChatInstance) CreateStreamChatRequest(props *adaptercommon.ChatProps, c
|
|||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ticks := 0
|
||||||
|
scanErr := utils.EventScanner(&utils.EventScannerProps{
|
||||||
|
Method: "POST",
|
||||||
|
Uri: c.GetChatEndpoint(props.Model, true),
|
||||||
|
Headers: map[string]string{
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
},
|
||||||
|
Body: c.GetGeminiChatBody(props),
|
||||||
|
Callback: func(data string) error {
|
||||||
|
ticks += 1
|
||||||
|
|
||||||
|
if form := utils.UnmarshalForm[GeminiStreamResponse](data); form != nil {
|
||||||
|
if len(form.Candidates) != 0 && len(form.Candidates[0].Content.Parts) != 0 {
|
||||||
|
return callback(&globals.Chunk{
|
||||||
|
Content: form.Candidates[0].Content.Parts[0].Text,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if form := utils.UnmarshalForm[GeminiChatErrorResponse](data); form != nil {
|
||||||
|
return fmt.Errorf("gemini error: %s (code: %d, status: %s)", form.Error.Message, form.Error.Code, form.Error.Status)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}, props.Proxy)
|
||||||
|
|
||||||
|
if scanErr != nil {
|
||||||
|
if scanErr.Error != nil && strings.Contains(scanErr.Error.Error(), "status code: 404") {
|
||||||
|
// downgrade to non-stream request
|
||||||
|
response, err := c.CreateChatRequest(props)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return callback(&globals.Chunk{Content: response})
|
||||||
|
}
|
||||||
|
|
||||||
|
if scanErr.Body != "" {
|
||||||
|
if form := utils.UnmarshalForm[GeminiChatErrorResponse](scanErr.Body); form != nil {
|
||||||
|
return fmt.Errorf("gemini error: %s (code: %d, status: %s)", form.Error.Message, form.Error.Code, form.Error.Status)
|
||||||
|
}
|
||||||
|
return fmt.Errorf("gemini error: %s", scanErr.Body)
|
||||||
|
}
|
||||||
|
return fmt.Errorf("gemini error: %v", scanErr.Error)
|
||||||
|
}
|
||||||
|
|
||||||
|
if ticks == 0 {
|
||||||
|
return errors.New("no response")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ChatInstance) GetLatestPrompt(props *adaptercommon.ChatProps) string {
|
||||||
|
if len(props.Message) == 0 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return props.Message[len(props.Message)-1].Content
|
||||||
|
}
|
||||||
|
82
adapter/palm2/image.go
Normal file
82
adapter/palm2/image.go
Normal file
@ -0,0 +1,82 @@
|
|||||||
|
package palm2
|
||||||
|
|
||||||
|
import (
|
||||||
|
adaptercommon "chat/adapter/common"
|
||||||
|
"chat/globals"
|
||||||
|
"chat/utils"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ImageProps struct {
|
||||||
|
Model string
|
||||||
|
Prompt string
|
||||||
|
Proxy globals.ProxyConfig
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ChatInstance) GetImageEndpoint(model string) string {
|
||||||
|
return fmt.Sprintf("%s/v1beta/models/%s:predict?key=%s", c.Endpoint, model, c.ApiKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateImageRequest will create a gemini imagen from prompt, return base64 of image and error
|
||||||
|
func (c *ChatInstance) CreateImageRequest(props ImageProps) (string, error) {
|
||||||
|
res, err := utils.Post(
|
||||||
|
c.GetImageEndpoint(props.Model),
|
||||||
|
map[string]string{
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
},
|
||||||
|
ImageRequest{
|
||||||
|
Instances: []ImageInstance{
|
||||||
|
{
|
||||||
|
Prompt: props.Prompt,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Parameters: ImageParameters{
|
||||||
|
SampleCount: 1,
|
||||||
|
AspectRatio: "1:1",
|
||||||
|
PersonGeneration: "allow_adult",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
props.Proxy,
|
||||||
|
)
|
||||||
|
|
||||||
|
if err != nil || res == nil {
|
||||||
|
return "", fmt.Errorf("gemini error: %s", err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
data := utils.MapToStruct[ImageResponse](res)
|
||||||
|
if data == nil {
|
||||||
|
return "", fmt.Errorf("gemini error: cannot parse response")
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(data.Predictions) == 0 {
|
||||||
|
return "", fmt.Errorf("gemini error: no image generated")
|
||||||
|
}
|
||||||
|
|
||||||
|
return data.Predictions[0].BytesBase64Encoded, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateImage will create a gemini imagen from prompt, return markdown of image
|
||||||
|
func (c *ChatInstance) CreateImage(props *adaptercommon.ChatProps) (string, error) {
|
||||||
|
if !globals.IsGoogleImagenModel(props.Model) {
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
base64Data, err := c.CreateImageRequest(ImageProps{
|
||||||
|
Model: props.Model,
|
||||||
|
Prompt: c.GetLatestPrompt(props),
|
||||||
|
Proxy: props.Proxy,
|
||||||
|
})
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
if strings.Contains(err.Error(), "safety") {
|
||||||
|
return err.Error(), nil
|
||||||
|
}
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert base64 to data URL format
|
||||||
|
dataUrl := fmt.Sprintf("data:image/png;base64,%s", base64Data)
|
||||||
|
url := utils.StoreImage(dataUrl)
|
||||||
|
return utils.GetImageMarkdown(url), nil
|
||||||
|
}
|
@ -70,3 +70,40 @@ type GeminiChatErrorResponse struct {
|
|||||||
Status string `json:"status"`
|
Status string `json:"status"`
|
||||||
} `json:"error"`
|
} `json:"error"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type GeminiStreamResponse struct {
|
||||||
|
Candidates []struct {
|
||||||
|
Content struct {
|
||||||
|
Parts []struct {
|
||||||
|
Text string `json:"text"`
|
||||||
|
} `json:"parts"`
|
||||||
|
Role string `json:"role"`
|
||||||
|
} `json:"content"`
|
||||||
|
} `json:"candidates"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ImageRequest is the native http request body for imagen
|
||||||
|
type ImageRequest struct {
|
||||||
|
Instances []ImageInstance `json:"instances"`
|
||||||
|
Parameters ImageParameters `json:"parameters"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ImageInstance struct {
|
||||||
|
Prompt string `json:"prompt"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ImageParameters struct {
|
||||||
|
SampleCount int `json:"sampleCount,omitempty"`
|
||||||
|
AspectRatio string `json:"aspectRatio,omitempty"`
|
||||||
|
PersonGeneration string `json:"personGeneration,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ImageResponse is the native http response body for imagen
|
||||||
|
type ImageResponse struct {
|
||||||
|
Predictions []ImagePrediction `json:"predictions"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ImagePrediction struct {
|
||||||
|
MimeType string `json:"mimeType"`
|
||||||
|
BytesBase64Encoded string `json:"bytesBase64Encoded"`
|
||||||
|
}
|
||||||
|
@ -109,6 +109,15 @@ const (
|
|||||||
GeminiProVision = "gemini-pro-vision"
|
GeminiProVision = "gemini-pro-vision"
|
||||||
Gemini15ProLatest = "gemini-1.5-pro-latest"
|
Gemini15ProLatest = "gemini-1.5-pro-latest"
|
||||||
Gemini15FlashLatest = "gemini-1.5-flash-latest"
|
Gemini15FlashLatest = "gemini-1.5-flash-latest"
|
||||||
|
Gemini20ProExp = "gemini-2.0-pro-exp-02-05"
|
||||||
|
Gemini20Flash = "gemini-2.0-flash"
|
||||||
|
Gemini20FlashExp = "gemini-2.0-flash-exp"
|
||||||
|
Gemini20Flash001 = "gemini-2.0-flash-001"
|
||||||
|
Gemini20FlashThinkingExp = "gemini-2.0-flash-thinking-exp-01-21"
|
||||||
|
Gemini20FlashLitePreview = "gemini-2.0-flash-lite-preview-02-05"
|
||||||
|
Gemini20FlashThinkingExp1219 = "gemini-2.0-flash-thinking-exp-1219"
|
||||||
|
GeminiExp1206 = "gemini-exp-1206"
|
||||||
|
GoogleImagen002 = "imagen-3.0-generate-002"
|
||||||
BingCreative = "bing-creative"
|
BingCreative = "bing-creative"
|
||||||
BingBalanced = "bing-balanced"
|
BingBalanced = "bing-balanced"
|
||||||
BingPrecise = "bing-precise"
|
BingPrecise = "bing-precise"
|
||||||
@ -141,6 +150,10 @@ var OpenAIDalleModels = []string{
|
|||||||
Dalle, Dalle2, Dalle3,
|
Dalle, Dalle2, Dalle3,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var GoogleImagenModels = []string{
|
||||||
|
GoogleImagen002,
|
||||||
|
}
|
||||||
|
|
||||||
var VisionModels = []string{
|
var VisionModels = []string{
|
||||||
GPT4VisionPreview, GPT41106VisionPreview, GPT4Turbo, GPT4Turbo20240409, GPT4O, GPT4O20240513, // openai
|
GPT4VisionPreview, GPT41106VisionPreview, GPT4Turbo, GPT4Turbo20240409, GPT4O, GPT4O20240513, // openai
|
||||||
GeminiProVision, Gemini15ProLatest, Gemini15FlashLatest, // gemini
|
GeminiProVision, Gemini15ProLatest, Gemini15FlashLatest, // gemini
|
||||||
@ -166,6 +179,11 @@ func IsOpenAIDalleModel(model string) bool {
|
|||||||
return in(model, OpenAIDalleModels) && !strings.Contains(model, "gpt-4-dalle")
|
return in(model, OpenAIDalleModels) && !strings.Contains(model, "gpt-4-dalle")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func IsGoogleImagenModel(model string) bool {
|
||||||
|
// using image generation api if model is in imagen models
|
||||||
|
return in(model, GoogleImagenModels)
|
||||||
|
}
|
||||||
|
|
||||||
func IsVisionModel(model string) bool {
|
func IsVisionModel(model string) bool {
|
||||||
return in(model, VisionModels) && !in(model, VisionSkipModels)
|
return in(model, VisionModels) && !in(model, VisionSkipModels)
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user