diff --git a/adapter/palm2/chat.go b/adapter/palm2/chat.go index 7e1fc03..703db09 100644 --- a/adapter/palm2/chat.go +++ b/adapter/palm2/chat.go @@ -6,13 +6,23 @@ import ( "fmt" ) +var geminiMaxImages = 16 + type ChatProps struct { - Model string - Message []globals.Message + Model string + Message []globals.Message + Temperature *float64 + TopP *float64 + TopK *int + MaxOutputTokens *int } func (c *ChatInstance) GetChatEndpoint(model string) string { - return fmt.Sprintf("%s/v1beta2/models/%s:generateMessage?key=%s", c.Endpoint, model, c.ApiKey) + if model == globals.ChatBison001 { + return fmt.Sprintf("%s/v1beta2/models/%s:generateMessage?key=%s", c.Endpoint, model, c.ApiKey) + } + + return fmt.Sprintf("%s/v1beta/models/%s:generateContent?key=%s", c.Endpoint, model, c.ApiKey) } func (c *ChatInstance) ConvertMessage(message []globals.Message) []PalmMessage { @@ -41,33 +51,75 @@ func (c *ChatInstance) ConvertMessage(message []globals.Message) []PalmMessage { return result } -func (c *ChatInstance) GetChatBody(props *ChatProps) *ChatBody { - return &ChatBody{ - Prompt: Prompt{ +func (c *ChatInstance) GetPalm2ChatBody(props *ChatProps) *PalmChatBody { + return &PalmChatBody{ + Prompt: PalmPrompt{ Messages: c.ConvertMessage(props.Message), }, } } -func (c *ChatInstance) CreateChatRequest(props *ChatProps) (string, error) { - uri := c.GetChatEndpoint(props.Model) - data, err := utils.Post(uri, map[string]string{ - "Content-Type": "application/json", - }, c.GetChatBody(props)) - - if err != nil { - return "", fmt.Errorf("palm2 error: %s", err.Error()) +func (c *ChatInstance) GetGeminiChatBody(props *ChatProps) *GeminiChatBody { + return &GeminiChatBody{ + Contents: c.GetGeminiContents(props.Model, props.Message), + GenerationConfig: GeminiConfig{ + Temperature: props.Temperature, + MaxOutputTokens: props.MaxOutputTokens, + TopP: props.TopP, + TopK: props.TopK, + }, } +} - if form := utils.MapToStruct[ChatResponse](data); form != nil { +func (c *ChatInstance) GetPalm2ChatResponse(data interface{}) (string, error) { + if form := utils.MapToStruct[PalmChatResponse](data); form != nil { if len(form.Candidates) == 0 { - return "I don't know how to respond to that. Please try another question.", nil + return "", fmt.Errorf("palm2 error: the content violates content policy") } return form.Candidates[0].Content, nil } return "", fmt.Errorf("palm2 error: cannot parse response") } +func (c *ChatInstance) GetGeminiChatResponse(data interface{}) (string, error) { + if form := utils.MapToStruct[GeminiChatResponse](data); form != nil { + if len(form.Candidates) != 0 && len(form.Candidates[0].Content.Parts) != 0 { + return form.Candidates[0].Content.Parts[0].Text, nil + } + } + + if form := utils.MapToStruct[GeminiChatErrorResponse](data); form != nil { + return "", fmt.Errorf("gemini error: %s (code: %d, status: %s)", form.Error.Message, form.Error.Code, form.Error.Status) + } + + return "", fmt.Errorf("gemini: cannot parse response") +} + +func (c *ChatInstance) CreateChatRequest(props *ChatProps) (string, error) { + uri := c.GetChatEndpoint(props.Model) + + if props.Model == globals.ChatBison001 { + data, err := utils.Post(uri, map[string]string{ + "Content-Type": "application/json", + }, c.GetPalm2ChatBody(props)) + + if err != nil { + return "", fmt.Errorf("palm2 error: %s", err.Error()) + } + return c.GetPalm2ChatResponse(data) + } + + data, err := utils.Post(uri, map[string]string{ + "Content-Type": "application/json", + }, c.GetGeminiChatBody(props)) + + if err != nil { + return "", fmt.Errorf("gemini error: %s", err.Error()) + } + + return c.GetGeminiChatResponse(data) +} + // CreateStreamChatRequest is the mock stream request for palm2 // tips: palm2 does not support stream request func (c *ChatInstance) CreateStreamChatRequest(props *ChatProps, callback globals.Hook) error { diff --git a/adapter/palm2/formatter.go b/adapter/palm2/formatter.go new file mode 100644 index 0000000..cd15db3 --- /dev/null +++ b/adapter/palm2/formatter.go @@ -0,0 +1,106 @@ +package palm2 + +import ( + "chat/globals" + "chat/utils" + "strings" +) + +func getGeminiRole(role string) string { + switch role { + case globals.User: + return GeminiUserType + case globals.Assistant, globals.Tool, globals.System: + return GeminiModelType + default: + return GeminiUserType + } +} + +func getMimeType(content string) string { + segment := strings.Split(content, ".") + if len(segment) == 0 || len(segment) == 1 { + return "image/png" + } + + suffix := strings.TrimSpace(strings.ToLower(segment[len(segment)-1])) + + switch suffix { + case "png": + return "image/png" + case "jpg", "jpeg": + return "image/jpeg" + case "gif": + return "image/gif" + case "webp": + return "image/webp" + case "heif": + return "image/heif" + case "heic": + return "image/heic" + default: + return "image/png" + } +} + +func getGeminiContent(parts []GeminiChatPart, content string, model string) []GeminiChatPart { + parts = append(parts, GeminiChatPart{ + Text: &content, + }) + + if model == globals.GeminiPro { + return parts + } + + urls := utils.ExtractImageUrls(content) + if len(urls) > geminiMaxImages { + urls = urls[:geminiMaxImages] + } + + for _, url := range urls { + data, err := utils.ConvertToBase64(url) + if err != nil { + continue + } + + parts = append(parts, GeminiChatPart{ + InlineData: &GeminiInlineData{ + MimeType: getMimeType(url), + Data: data, + }, + }) + } + + return parts +} + +func (c *ChatInstance) GetGeminiContents(model string, message []globals.Message) []GeminiContent { + // gemini role should be user-model + + result := make([]GeminiContent, 0) + for _, item := range message { + role := getGeminiRole(item.Role) + if len(item.Content) == 0 { + // gemini model: message must include non empty content + continue + } + + if len(result) == 0 && getGeminiRole(item.Role) == GeminiModelType { + // gemini model: first message must be user + continue + } + + if len(result) > 0 && role == result[len(result)-1].Role { + // gemini model: messages must alternate between authors + result[len(result)-1].Parts = getGeminiContent(result[len(result)-1].Parts, item.Content, model) + continue + } + + result = append(result, GeminiContent{ + Role: getGeminiRole(item.Role), + Parts: getGeminiContent(make([]GeminiChatPart, 0), item.Content, model), + }) + } + + return result +} diff --git a/adapter/palm2/types.go b/adapter/palm2/types.go index fa17199..c3e1ee8 100644 --- a/adapter/palm2/types.go +++ b/adapter/palm2/types.go @@ -1,20 +1,72 @@ package palm2 +const ( + GeminiUserType = "user" + GeminiModelType = "model" +) + type PalmMessage struct { Author string `json:"author"` Content string `json:"content"` } -// ChatBody is the native http request body for palm2 -type ChatBody struct { - Prompt Prompt `json:"prompt"` +// PalmChatBody is the native http request body for palm2 +type PalmChatBody struct { + Prompt PalmPrompt `json:"prompt"` } -type Prompt struct { +type PalmPrompt struct { Messages []PalmMessage `json:"messages"` } -// ChatResponse is the native http response body for palm2 -type ChatResponse struct { +// PalmChatResponse is the native http response body for palm2 +type PalmChatResponse struct { Candidates []PalmMessage `json:"candidates"` } + +// GeminiChatBody is the native http request body for gemini +type GeminiChatBody struct { + Contents []GeminiContent `json:"contents"` + GenerationConfig GeminiConfig `json:"generationConfig"` +} + +type GeminiConfig struct { + Temperature *float64 `json:"temperature,omitempty"` + MaxOutputTokens *int `json:"maxOutputTokens,omitempty"` + TopP *float64 `json:"topP,omitempty"` + TopK *int `json:"topK,omitempty"` +} + +type GeminiContent struct { + Role string `json:"role"` + Parts []GeminiChatPart `json:"parts"` +} + +type GeminiChatPart struct { + Text *string `json:"text,omitempty"` + InlineData *GeminiInlineData `json:"inline_data,omitempty"` +} + +type GeminiInlineData struct { + MimeType string `json:"mime_type"` + Data string `json:"data"` +} + +type GeminiChatResponse struct { + Candidates []struct { + Content struct { + Parts []struct { + Text string `json:"text"` + } `json:"parts"` + Role string `json:"role"` + } `json:"content"` + } `json:"candidates"` +} + +type GeminiChatErrorResponse struct { + Error struct { + Code int `json:"code"` + Message string `json:"message"` + Status string `json:"status"` + } `json:"error"` +} diff --git a/app/public/icons/gemini.jpeg b/app/public/icons/gemini.jpeg new file mode 100644 index 0000000..e16bcfa Binary files /dev/null and b/app/public/icons/gemini.jpeg differ diff --git a/app/src/admin/channel.ts b/app/src/admin/channel.ts index 934707c..0284579 100644 --- a/app/src/admin/channel.ts +++ b/app/src/admin/channel.ts @@ -33,7 +33,7 @@ export const ChannelTypes: Record = { baichuan: "百川 AI", skylark: "火山方舟", bing: "New Bing", - palm: "Google PaLM2", + palm: "Google Gemini", midjourney: "Midjourney", oneapi: "Nio API", }; @@ -141,7 +141,7 @@ export const ChannelInfos: Record = { id: 11, endpoint: "https://generativelanguage.googleapis.com", format: "", - models: ["chat-bison-001"], + models: ["chat-bison-001", "gemini-pro", "gemini-pro-vision"], }, midjourney: { id: 12, diff --git a/app/src/admin/colors.ts b/app/src/admin/colors.ts index 173cd24..4e8c8f5 100644 --- a/app/src/admin/colors.ts +++ b/app/src/admin/colors.ts @@ -43,6 +43,8 @@ export const modelColorMapper: Record = { "spark-desk-v3": "#06b3e8", "chat-bison-001": "#f82a53", + "gemini-pro": "#f82a53", + "gemini-pro-vision": "#f82a53", "bing-creative": "#2673e7", "bing-balanced": "#2673e7", diff --git a/app/src/conf.ts b/app/src/conf.ts index 996af21..fdc6233 100644 --- a/app/src/conf.ts +++ b/app/src/conf.ts @@ -274,6 +274,22 @@ export const supportModels: Model[] = [ tag: ["free", "english-model"], }, + // gemini + { + id: "gemini-pro", + name: "Gemini Pro", + free: true, + auth: true, + tag: ["free", "official"], + }, + { + id: "gemini-pro-vision", + name: "Gemini Pro Vision", + free: true, + auth: true, + tag: ["free", "official", "multi-modal"], + }, + // drawing models { id: "midjourney", @@ -346,6 +362,9 @@ export const defaultModels = [ "zhipu-chatglm-turbo", "baichuan-53b", + "gemini-pro", + "gemini-pro-vision", + "dall-e-2", "midjourney-fast", "stable-diffusion", @@ -412,6 +431,8 @@ export const modelAvatars: Record = { "midjourney-turbo": "midjourney.jpg", "bing-creative": "newbing.jpg", "chat-bison-001": "palm2.webp", + "gemini-pro": "gemini.jpeg", + "gemini-pro-vision": "gemini.jpeg", "zhipu-chatglm-turbo": "chatglm.png", "qwen-plus-net": "tongyi.png", "qwen-plus": "tongyi.png", diff --git a/app/src/routes/admin/System.tsx b/app/src/routes/admin/System.tsx index 5da75ff..76c1802 100644 --- a/app/src/routes/admin/System.tsx +++ b/app/src/routes/admin/System.tsx @@ -168,8 +168,8 @@ function Mail({ data, dispatch, onChange }: CompProps) {
- - diff --git a/globals/variables.go b/globals/variables.go index 7afc448..90ac7a6 100644 --- a/globals/variables.go +++ b/globals/variables.go @@ -69,6 +69,8 @@ const ( SparkDeskV2 = "spark-desk-v2" SparkDeskV3 = "spark-desk-v3" ChatBison001 = "chat-bison-001" + GeminiPro = "gemini-pro" + GeminiProVision = "gemini-pro-vision" BingCreative = "bing-creative" BingBalanced = "bing-balanced" BingPrecise = "bing-precise" diff --git a/utils/char.go b/utils/char.go index f3e5baf..57ab51e 100644 --- a/utils/char.go +++ b/utils/char.go @@ -167,8 +167,8 @@ func ExtractUrls(data string) []string { func ExtractImageUrls(data string) []string { // https://platform.openai.com/docs/guides/vision/what-type-of-files-can-i-upload - re := regexp.MustCompile(`(https?://\S+\.(?:png|jpg|jpeg|gif|webp))`) - return re.FindAllString(data, -1) + re := regexp.MustCompile(`(https?://\S+\.(?:png|jpg|jpeg|gif|webp|heif|heic))`) + return re.FindAllString(strings.ToLower(data), -1) } func ContainUnicode(data string) bool { diff --git a/utils/encrypt.go b/utils/encrypt.go index 7fc8b80..b7ee38c 100644 --- a/utils/encrypt.go +++ b/utils/encrypt.go @@ -6,6 +6,7 @@ import ( "crypto/md5" crand "crypto/rand" "crypto/sha256" + "encoding/base64" "encoding/hex" "io" ) @@ -22,6 +23,30 @@ func Sha2EncryptForm(form interface{}) string { return hex.EncodeToString(hash[:]) } +func Base64Encode(raw string) string { + return base64.StdEncoding.EncodeToString([]byte(raw)) +} + +func Base64EncodeBytes(raw []byte) string { + return base64.StdEncoding.EncodeToString(raw) +} + +func Base64Decode(raw string) string { + if data, err := base64.StdEncoding.DecodeString(raw); err == nil { + return string(data) + } else { + return "" + } +} + +func Base64DecodeBytes(raw string) []byte { + if data, err := base64.StdEncoding.DecodeString(raw); err == nil { + return data + } else { + return []byte{} + } +} + func Md5Encrypt(raw string) string { // return 32-bit hash hash := md5.Sum([]byte(raw)) diff --git a/utils/image.go b/utils/image.go index d789b0c..da24526 100644 --- a/utils/image.go +++ b/utils/image.go @@ -6,6 +6,7 @@ import ( "image" "image/gif" "image/jpeg" + "io" "math" "net/http" "path" @@ -51,6 +52,22 @@ func NewImage(url string) (*Image, error) { return &Image{Object: img}, nil } +func ConvertToBase64(url string) (string, error) { + res, err := http.Get(url) + if err != nil { + return "", err + } + + defer res.Body.Close() + + data, err := io.ReadAll(res.Body) + if err != nil { + return "", err + } + + return Base64EncodeBytes(data), nil +} + func (i *Image) GetWidth() int { return i.Object.Bounds().Max.X }