diff --git a/adapter/palm2/chat.go b/adapter/palm2/chat.go
index 09255d5..843f06c 100644
--- a/adapter/palm2/chat.go
+++ b/adapter/palm2/chat.go
@@ -4,16 +4,22 @@ import (
 	adaptercommon "chat/adapter/common"
 	"chat/globals"
 	"chat/utils"
+	"errors"
 	"fmt"
+	"strings"
 )
 
 var geminiMaxImages = 16
 
-func (c *ChatInstance) GetChatEndpoint(model string) string {
+func (c *ChatInstance) GetChatEndpoint(model string, stream bool) string {
 	if model == globals.ChatBison001 {
 		return fmt.Sprintf("%s/v1beta2/models/%s:generateMessage?key=%s", c.Endpoint, model, c.ApiKey)
 	}
 
+	if stream {
+		return fmt.Sprintf("%s/v1beta/models/%s:streamGenerateContent?alt=sse&key=%s", c.Endpoint, model, c.ApiKey)
+	}
+
 	return fmt.Sprintf("%s/v1beta/models/%s:generateContent?key=%s", c.Endpoint, model, c.ApiKey)
 }
 
@@ -88,7 +94,7 @@ func (c *ChatInstance) GetGeminiChatResponse(data interface{}) (string, error) {
 }
 
 func (c *ChatInstance) CreateChatRequest(props *adaptercommon.ChatProps) (string, error) {
-	uri := c.GetChatEndpoint(props.Model)
+	uri := c.GetChatEndpoint(props.Model, false)
 
 	if props.Model == globals.ChatBison001 {
 		data, err := utils.Post(uri, map[string]string{
@@ -112,18 +118,104 @@ func (c *ChatInstance) CreateChatRequest(props *adaptercommon.ChatProps) (string
 	return c.GetGeminiChatResponse(data)
 }
 
-// CreateStreamChatRequest is the mock stream request for palm2
-// tips: palm2 does not support stream request
+// CreateStreamChatRequest streams a chat completion to callback.
+// Imagen models emit one chunk produced by CreateImage, chat-bison-001 replays
+// the non-stream response piece by piece via utils.SplitItem, and every other
+// model reads the streamGenerateContent SSE endpoint, downgrading to a single
+// non-stream request when the scanner error contains "status code: 404".
 func (c *ChatInstance) CreateStreamChatRequest(props *adaptercommon.ChatProps, callback globals.Hook) error {
-	response, err := c.CreateChatRequest(props)
-	if err != nil {
-		return err
-	}
-
-	for _, item := range utils.SplitItem(response, " ") {
-		if err := callback(&globals.Chunk{Content: item}); err != nil {
+	// Handle imagen models
+	if globals.IsGoogleImagenModel(props.Model) {
+		response, err := c.CreateImage(props)
+		if err != nil {
 			return err
 		}
+		return callback(&globals.Chunk{Content: response})
 	}
+
+	// Handle chat models
+	if props.Model == globals.ChatBison001 {
+		response, err := c.CreateChatRequest(props)
+		if err != nil {
+			return err
+		}
+
+		for _, item := range utils.SplitItem(response, " ") {
+			if err := callback(&globals.Chunk{Content: item}); err != nil {
+				return err
+			}
+		}
+		return nil
+	}
+
+	ticks := 0
+	scanErr := utils.EventScanner(&utils.EventScannerProps{
+		Method: "POST",
+		Uri:    c.GetChatEndpoint(props.Model, true),
+		Headers: map[string]string{
+			"Content-Type": "application/json",
+		},
+		Body: c.GetGeminiChatBody(props),
+		Callback: func(data string) error {
+			ticks += 1
+
+			// Forward the text of every part of the first candidate; reading
+			// only Parts[0] would drop the text of any additional part.
+			if form := utils.UnmarshalForm[GeminiStreamResponse](data); form != nil && len(form.Candidates) != 0 {
+				parts := form.Candidates[0].Content.Parts
+				if len(parts) == 0 {
+					return nil
+				}
+
+				var text strings.Builder
+				for _, part := range parts {
+					text.WriteString(part.Text)
+				}
+				return callback(&globals.Chunk{Content: text.String()})
+			}
+
+			// Events without candidates are checked for an error payload here.
+			// NOTE(review): presumably UnmarshalForm returns non-nil for any valid
+			// JSON, which is why the check also requires a non-empty message.
+			if form := utils.UnmarshalForm[GeminiChatErrorResponse](data); form != nil && form.Error.Message != "" {
+				return fmt.Errorf("gemini error: %s (code: %d, status: %s)", form.Error.Message, form.Error.Code, form.Error.Status)
+			}
+
+			return nil
+		},
+	}, props.Proxy)
+
+	if scanErr != nil {
+		if scanErr.Error != nil && strings.Contains(scanErr.Error.Error(), "status code: 404") {
+			// downgrade to non-stream request
+			response, err := c.CreateChatRequest(props)
+			if err != nil {
+				return err
+			}
+			return callback(&globals.Chunk{Content: response})
+		}
+
+		if scanErr.Body != "" {
+			if form := utils.UnmarshalForm[GeminiChatErrorResponse](scanErr.Body); form != nil {
+				return fmt.Errorf("gemini error: %s (code: %d, status: %s)", form.Error.Message, form.Error.Code, form.Error.Status)
+			}
+			return fmt.Errorf("gemini error: %s", scanErr.Body)
+		}
+		return fmt.Errorf("gemini error: %v", scanErr.Error)
+	}
+
+	if ticks == 0 {
+		return errors.New("no response")
+	}
+
 	return nil
 }
+
+// GetLatestPrompt returns the content of the last message in props.Message,
+// or an empty string when there are no messages.
+func (c *ChatInstance) GetLatestPrompt(props *adaptercommon.ChatProps) string {
+	if len(props.Message) == 0 {
+		return ""
+	}
+	return 
props.Message[len(props.Message)-1].Content
+}
diff --git a/adapter/palm2/image.go b/adapter/palm2/image.go
new file mode 100644
index 0000000..7fe379a
--- /dev/null
+++ b/adapter/palm2/image.go
@@ -0,0 +1,92 @@
+package palm2
+
+import (
+	adaptercommon "chat/adapter/common"
+	"chat/globals"
+	"chat/utils"
+	"fmt"
+	"strings"
+)
+
+// ImageProps carries the model, prompt and proxy settings for CreateImageRequest.
+type ImageProps struct {
+	Model  string
+	Prompt string
+	Proxy  globals.ProxyConfig
+}
+
+// GetImageEndpoint builds the imagen :predict URL for model, including the API key as a query parameter.
+func (c *ChatInstance) GetImageEndpoint(model string) string {
+	return fmt.Sprintf("%s/v1beta/models/%s:predict?key=%s", c.Endpoint, model, c.ApiKey)
+}
+
+// CreateImageRequest asks the imagen predict endpoint for one 1:1 image of
+// props.Prompt and returns the base64 payload of the first prediction.
+// It returns an error when the request fails, the response is empty or cannot
+// be parsed, or no prediction is returned.
+func (c *ChatInstance) CreateImageRequest(props ImageProps) (string, error) {
+	res, err := utils.Post(
+		c.GetImageEndpoint(props.Model),
+		map[string]string{
+			"Content-Type": "application/json",
+		},
+		ImageRequest{
+			Instances: []ImageInstance{
+				{
+					Prompt: props.Prompt,
+				},
+			},
+			Parameters: ImageParameters{
+				SampleCount:      1,
+				AspectRatio:      "1:1",
+				PersonGeneration: "allow_adult",
+			},
+		},
+		props.Proxy,
+	)
+
+	// err and res are checked separately: if res is nil while err is nil,
+	// calling err.Error() would panic with a nil pointer dereference.
+	if err != nil {
+		return "", fmt.Errorf("gemini error: %w", err)
+	}
+	if res == nil {
+		return "", fmt.Errorf("gemini error: empty response")
+	}
+
+	data := utils.MapToStruct[ImageResponse](res)
+	if data == nil {
+		return "", fmt.Errorf("gemini error: cannot parse response")
+	}
+
+	if len(data.Predictions) == 0 {
+		return "", fmt.Errorf("gemini error: no image generated")
+	}
+
+	return data.Predictions[0].BytesBase64Encoded, nil
+}
+
+// CreateImage will create a gemini imagen from prompt, return markdown of image
+func (c *ChatInstance) CreateImage(props *adaptercommon.ChatProps) (string, error) {
+	if !globals.IsGoogleImagenModel(props.Model) {
+		return "", nil
+	}
+
+	base64Data, err := c.CreateImageRequest(ImageProps{
+		Model:  props.Model,
+		Prompt: c.GetLatestPrompt(props),
+		Proxy:  props.Proxy,
+	})
+
+	if err != nil {
+		if strings.Contains(err.Error(), "safety") {
+			return err.Error(), nil
+		}
+		return "", err
+	}
+
+	// 
Convert base64 to data URL format + dataUrl := fmt.Sprintf("data:image/png;base64,%s", base64Data) + url := utils.StoreImage(dataUrl) + return utils.GetImageMarkdown(url), nil +} diff --git a/adapter/palm2/types.go b/adapter/palm2/types.go index 7a41551..549cdfb 100644 --- a/adapter/palm2/types.go +++ b/adapter/palm2/types.go @@ -70,3 +70,40 @@ type GeminiChatErrorResponse struct { Status string `json:"status"` } `json:"error"` } + +type GeminiStreamResponse struct { + Candidates []struct { + Content struct { + Parts []struct { + Text string `json:"text"` + } `json:"parts"` + Role string `json:"role"` + } `json:"content"` + } `json:"candidates"` +} + +// ImageRequest is the native http request body for imagen +type ImageRequest struct { + Instances []ImageInstance `json:"instances"` + Parameters ImageParameters `json:"parameters"` +} + +type ImageInstance struct { + Prompt string `json:"prompt"` +} + +type ImageParameters struct { + SampleCount int `json:"sampleCount,omitempty"` + AspectRatio string `json:"aspectRatio,omitempty"` + PersonGeneration string `json:"personGeneration,omitempty"` +} + +// ImageResponse is the native http response body for imagen +type ImageResponse struct { + Predictions []ImagePrediction `json:"predictions"` +} + +type ImagePrediction struct { + MimeType string `json:"mimeType"` + BytesBase64Encoded string `json:"bytesBase64Encoded"` +} diff --git a/globals/variables.go b/globals/variables.go index 32c34f4..9461c32 100644 --- a/globals/variables.go +++ b/globals/variables.go @@ -61,86 +61,99 @@ func OriginIsOpen(c *gin.Context) bool { } const ( - GPT3Turbo = "gpt-3.5-turbo" - GPT3TurboInstruct = "gpt-3.5-turbo-instruct" - GPT3Turbo0613 = "gpt-3.5-turbo-0613" - GPT3Turbo0301 = "gpt-3.5-turbo-0301" - GPT3Turbo1106 = "gpt-3.5-turbo-1106" - GPT3Turbo0125 = "gpt-3.5-turbo-0125" - GPT3Turbo16k = "gpt-3.5-turbo-16k" - GPT3Turbo16k0613 = "gpt-3.5-turbo-16k-0613" - GPT3Turbo16k0301 = "gpt-3.5-turbo-16k-0301" - GPT4 = "gpt-4" - GPT4All = 
"gpt-4-all" - GPT4Vision = "gpt-4-v" - GPT4Dalle = "gpt-4-dalle" - GPT40314 = "gpt-4-0314" - GPT40613 = "gpt-4-0613" - GPT41106Preview = "gpt-4-1106-preview" - GPT40125Preview = "gpt-4-0125-preview" - GPT4TurboPreview = "gpt-4-turbo-preview" - GPT4VisionPreview = "gpt-4-vision-preview" - GPT4Turbo = "gpt-4-turbo" - GPT4Turbo20240409 = "gpt-4-turbo-2024-04-09" - GPT41106VisionPreview = "gpt-4-1106-vision-preview" - GPT432k = "gpt-4-32k" - GPT432k0314 = "gpt-4-32k-0314" - GPT432k0613 = "gpt-4-32k-0613" - GPT4O = "gpt-4o" - GPT4O20240513 = "gpt-4o-2024-05-13" - Dalle = "dalle" - Dalle2 = "dall-e-2" - Dalle3 = "dall-e-3" - Claude1 = "claude-1" - Claude1100k = "claude-1.3" - Claude2 = "claude-1-100k" - Claude2100k = "claude-2" - Claude2200k = "claude-2.1" - Claude3 = "claude-3" - ClaudeSlack = "claude-slack" - SparkDeskLite = "spark-desk-lite" - SparkDeskPro = "spark-desk-pro" - SparkDeskPro128K = "spark-desk-pro-128k" - SparkDeskMax = "spark-desk-max" - SparkDeskMax32K = "spark-desk-max-32k" - SparkDeskV4Ultra = "spark-desk-4.0-ultra" - ChatBison001 = "chat-bison-001" - GeminiPro = "gemini-pro" - GeminiProVision = "gemini-pro-vision" - Gemini15ProLatest = "gemini-1.5-pro-latest" - Gemini15FlashLatest = "gemini-1.5-flash-latest" - BingCreative = "bing-creative" - BingBalanced = "bing-balanced" - BingPrecise = "bing-precise" - ZhiPuChatGLM4 = "glm-4" - ZhiPuChatGLM4Vision = "glm-4v" - ZhiPuChatGLM3Turbo = "glm-3-turbo" - ZhiPuChatGLMTurbo = "zhipu-chatglm-turbo" - ZhiPuChatGLMPro = "zhipu-chatglm-pro" - ZhiPuChatGLMStd = "zhipu-chatglm-std" - ZhiPuChatGLMLite = "zhipu-chatglm-lite" - QwenTurbo = "qwen-turbo" - QwenPlus = "qwen-plus" - QwenTurboNet = "qwen-turbo-net" - QwenPlusNet = "qwen-plus-net" - Midjourney = "midjourney" - MidjourneyFast = "midjourney-fast" - MidjourneyTurbo = "midjourney-turbo" - Hunyuan = "hunyuan" - GPT360V9 = "360-gpt-v9" - Baichuan53B = "baichuan-53b" - SkylarkLite = "skylark-lite-public" - SkylarkPlus = "skylark-plus-public" - SkylarkPro = 
"skylark-pro-public" - SkylarkChat = "skylark-chat" - DeepseekV3 = "deepseek-chat" - DeepseekR1 = "deepseek-reasoner" + GPT3Turbo = "gpt-3.5-turbo" + GPT3TurboInstruct = "gpt-3.5-turbo-instruct" + GPT3Turbo0613 = "gpt-3.5-turbo-0613" + GPT3Turbo0301 = "gpt-3.5-turbo-0301" + GPT3Turbo1106 = "gpt-3.5-turbo-1106" + GPT3Turbo0125 = "gpt-3.5-turbo-0125" + GPT3Turbo16k = "gpt-3.5-turbo-16k" + GPT3Turbo16k0613 = "gpt-3.5-turbo-16k-0613" + GPT3Turbo16k0301 = "gpt-3.5-turbo-16k-0301" + GPT4 = "gpt-4" + GPT4All = "gpt-4-all" + GPT4Vision = "gpt-4-v" + GPT4Dalle = "gpt-4-dalle" + GPT40314 = "gpt-4-0314" + GPT40613 = "gpt-4-0613" + GPT41106Preview = "gpt-4-1106-preview" + GPT40125Preview = "gpt-4-0125-preview" + GPT4TurboPreview = "gpt-4-turbo-preview" + GPT4VisionPreview = "gpt-4-vision-preview" + GPT4Turbo = "gpt-4-turbo" + GPT4Turbo20240409 = "gpt-4-turbo-2024-04-09" + GPT41106VisionPreview = "gpt-4-1106-vision-preview" + GPT432k = "gpt-4-32k" + GPT432k0314 = "gpt-4-32k-0314" + GPT432k0613 = "gpt-4-32k-0613" + GPT4O = "gpt-4o" + GPT4O20240513 = "gpt-4o-2024-05-13" + Dalle = "dalle" + Dalle2 = "dall-e-2" + Dalle3 = "dall-e-3" + Claude1 = "claude-1" + Claude1100k = "claude-1.3" + Claude2 = "claude-1-100k" + Claude2100k = "claude-2" + Claude2200k = "claude-2.1" + Claude3 = "claude-3" + ClaudeSlack = "claude-slack" + SparkDeskLite = "spark-desk-lite" + SparkDeskPro = "spark-desk-pro" + SparkDeskPro128K = "spark-desk-pro-128k" + SparkDeskMax = "spark-desk-max" + SparkDeskMax32K = "spark-desk-max-32k" + SparkDeskV4Ultra = "spark-desk-4.0-ultra" + ChatBison001 = "chat-bison-001" + GeminiPro = "gemini-pro" + GeminiProVision = "gemini-pro-vision" + Gemini15ProLatest = "gemini-1.5-pro-latest" + Gemini15FlashLatest = "gemini-1.5-flash-latest" + Gemini20ProExp = "gemini-2.0-pro-exp-02-05" + Gemini20Flash = "gemini-2.0-flash" + Gemini20FlashExp = "gemini-2.0-flash-exp" + Gemini20Flash001 = "gemini-2.0-flash-001" + Gemini20FlashThinkingExp = "gemini-2.0-flash-thinking-exp-01-21" + 
Gemini20FlashLitePreview = "gemini-2.0-flash-lite-preview-02-05" + Gemini20FlashThinkingExp1219 = "gemini-2.0-flash-thinking-exp-1219" + GeminiExp1206 = "gemini-exp-1206" + GoogleImagen002 = "imagen-3.0-generate-002" + BingCreative = "bing-creative" + BingBalanced = "bing-balanced" + BingPrecise = "bing-precise" + ZhiPuChatGLM4 = "glm-4" + ZhiPuChatGLM4Vision = "glm-4v" + ZhiPuChatGLM3Turbo = "glm-3-turbo" + ZhiPuChatGLMTurbo = "zhipu-chatglm-turbo" + ZhiPuChatGLMPro = "zhipu-chatglm-pro" + ZhiPuChatGLMStd = "zhipu-chatglm-std" + ZhiPuChatGLMLite = "zhipu-chatglm-lite" + QwenTurbo = "qwen-turbo" + QwenPlus = "qwen-plus" + QwenTurboNet = "qwen-turbo-net" + QwenPlusNet = "qwen-plus-net" + Midjourney = "midjourney" + MidjourneyFast = "midjourney-fast" + MidjourneyTurbo = "midjourney-turbo" + Hunyuan = "hunyuan" + GPT360V9 = "360-gpt-v9" + Baichuan53B = "baichuan-53b" + SkylarkLite = "skylark-lite-public" + SkylarkPlus = "skylark-plus-public" + SkylarkPro = "skylark-pro-public" + SkylarkChat = "skylark-chat" + DeepseekV3 = "deepseek-chat" + DeepseekR1 = "deepseek-reasoner" ) var OpenAIDalleModels = []string{ Dalle, Dalle2, Dalle3, } +var GoogleImagenModels = []string{ + GoogleImagen002, +} + var VisionModels = []string{ GPT4VisionPreview, GPT41106VisionPreview, GPT4Turbo, GPT4Turbo20240409, GPT4O, GPT4O20240513, // openai GeminiProVision, Gemini15ProLatest, Gemini15FlashLatest, // gemini @@ -166,6 +179,11 @@ func IsOpenAIDalleModel(model string) bool { return in(model, OpenAIDalleModels) && !strings.Contains(model, "gpt-4-dalle") } +func IsGoogleImagenModel(model string) bool { + // using image generation api if model is in imagen models + return in(model, GoogleImagenModels) +} + func IsVisionModel(model string) bool { return in(model, VisionModels) && !in(model, VisionSkipModels) }