feat: add gemini pro, gemini pro vision models

This commit is contained in:
Zhang Minghan 2023-12-28 09:32:14 +08:00
parent 3d1e8c8f87
commit 261e500840
12 changed files with 305 additions and 28 deletions

View File

@ -6,13 +6,23 @@ import (
"fmt" "fmt"
) )
// geminiMaxImages caps how many inline images are attached to a single
// gemini vision request; it is never mutated, so declare it as a const.
const geminiMaxImages = 16
type ChatProps struct { type ChatProps struct {
Model string Model string
Message []globals.Message Message []globals.Message
Temperature *float64
TopP *float64
TopK *int
MaxOutputTokens *int
} }
func (c *ChatInstance) GetChatEndpoint(model string) string { func (c *ChatInstance) GetChatEndpoint(model string) string {
return fmt.Sprintf("%s/v1beta2/models/%s:generateMessage?key=%s", c.Endpoint, model, c.ApiKey) if model == globals.ChatBison001 {
return fmt.Sprintf("%s/v1beta2/models/%s:generateMessage?key=%s", c.Endpoint, model, c.ApiKey)
}
return fmt.Sprintf("%s/v1beta/models/%s:generateContent?key=%s", c.Endpoint, model, c.ApiKey)
} }
func (c *ChatInstance) ConvertMessage(message []globals.Message) []PalmMessage { func (c *ChatInstance) ConvertMessage(message []globals.Message) []PalmMessage {
@ -41,33 +51,75 @@ func (c *ChatInstance) ConvertMessage(message []globals.Message) []PalmMessage {
return result return result
} }
func (c *ChatInstance) GetChatBody(props *ChatProps) *ChatBody { func (c *ChatInstance) GetPalm2ChatBody(props *ChatProps) *PalmChatBody {
return &ChatBody{ return &PalmChatBody{
Prompt: Prompt{ Prompt: PalmPrompt{
Messages: c.ConvertMessage(props.Message), Messages: c.ConvertMessage(props.Message),
}, },
} }
} }
func (c *ChatInstance) CreateChatRequest(props *ChatProps) (string, error) { func (c *ChatInstance) GetGeminiChatBody(props *ChatProps) *GeminiChatBody {
uri := c.GetChatEndpoint(props.Model) return &GeminiChatBody{
data, err := utils.Post(uri, map[string]string{ Contents: c.GetGeminiContents(props.Model, props.Message),
"Content-Type": "application/json", GenerationConfig: GeminiConfig{
}, c.GetChatBody(props)) Temperature: props.Temperature,
MaxOutputTokens: props.MaxOutputTokens,
if err != nil { TopP: props.TopP,
return "", fmt.Errorf("palm2 error: %s", err.Error()) TopK: props.TopK,
},
} }
}
if form := utils.MapToStruct[ChatResponse](data); form != nil { func (c *ChatInstance) GetPalm2ChatResponse(data interface{}) (string, error) {
if form := utils.MapToStruct[PalmChatResponse](data); form != nil {
if len(form.Candidates) == 0 { if len(form.Candidates) == 0 {
return "I don't know how to respond to that. Please try another question.", nil return "", fmt.Errorf("palm2 error: the content violates content policy")
} }
return form.Candidates[0].Content, nil return form.Candidates[0].Content, nil
} }
return "", fmt.Errorf("palm2 error: cannot parse response") return "", fmt.Errorf("palm2 error: cannot parse response")
} }
// GetGeminiChatResponse extracts the generated text from a raw gemini
// response, or a descriptive error when the payload is an api error body.
func (c *ChatInstance) GetGeminiChatResponse(data interface{}) (string, error) {
	if form := utils.MapToStruct[GeminiChatResponse](data); form != nil {
		if len(form.Candidates) > 0 {
			parts := form.Candidates[0].Content.Parts
			if len(parts) > 0 {
				return parts[0].Text, nil
			}
		}
	}

	// not a success payload: try to surface the structured api error instead
	if form := utils.MapToStruct[GeminiChatErrorResponse](data); form != nil {
		return "", fmt.Errorf("gemini error: %s (code: %d, status: %s)", form.Error.Message, form.Error.Code, form.Error.Status)
	}

	return "", fmt.Errorf("gemini: cannot parse response")
}
// CreateChatRequest sends a non-streaming chat request, dispatching to the
// legacy palm2 endpoint for chat-bison-001 and to gemini for everything else.
func (c *ChatInstance) CreateChatRequest(props *ChatProps) (string, error) {
	uri := c.GetChatEndpoint(props.Model)
	headers := map[string]string{
		"Content-Type": "application/json",
	}

	// legacy palm2 model keeps the generateMessage request/response format
	if props.Model == globals.ChatBison001 {
		data, err := utils.Post(uri, headers, c.GetPalm2ChatBody(props))
		if err != nil {
			return "", fmt.Errorf("palm2 error: %s", err.Error())
		}
		return c.GetPalm2ChatResponse(data)
	}

	data, err := utils.Post(uri, headers, c.GetGeminiChatBody(props))
	if err != nil {
		return "", fmt.Errorf("gemini error: %s", err.Error())
	}
	return c.GetGeminiChatResponse(data)
}
// CreateStreamChatRequest is the mock stream request for palm2 // CreateStreamChatRequest is the mock stream request for palm2
// tips: palm2 does not support stream request // tips: palm2 does not support stream request
func (c *ChatInstance) CreateStreamChatRequest(props *ChatProps, callback globals.Hook) error { func (c *ChatInstance) CreateStreamChatRequest(props *ChatProps, callback globals.Hook) error {

106
adapter/palm2/formatter.go Normal file
View File

@ -0,0 +1,106 @@
package palm2
import (
"chat/globals"
"chat/utils"
"strings"
)
// getGeminiRole maps the internal role name onto gemini's two-role scheme
// ("user" / "model"); anything unrecognized is treated as the user.
func getGeminiRole(role string) string {
	if role == globals.Assistant || role == globals.Tool || role == globals.System {
		return GeminiModelType
	}
	// globals.User and any unknown role
	return GeminiUserType
}
// getMimeType guesses the image MIME type from the file extension of the
// given URL or path. Unknown or missing extensions default to "image/png".
func getMimeType(content string) string {
	segment := strings.Split(content, ".")
	// strings.Split never returns an empty slice, so a single segment means
	// there is no extension at all (replaces the redundant ==0 || ==1 check)
	if len(segment) < 2 {
		return "image/png"
	}

	suffix := strings.TrimSpace(strings.ToLower(segment[len(segment)-1]))
	switch suffix {
	case "png":
		return "image/png"
	case "jpg", "jpeg":
		return "image/jpeg"
	case "gif":
		return "image/gif"
	case "webp":
		return "image/webp"
	case "heif":
		return "image/heif"
	case "heic":
		return "image/heic"
	default:
		// unrecognized suffix (e.g. query-string remnants): fall back to png
		return "image/png"
	}
}
// getGeminiContent appends the text as a chat part and, for vision-capable
// models, additionally inlines any referenced images as base64 data parts.
func getGeminiContent(parts []GeminiChatPart, content string, model string) []GeminiChatPart {
	parts = append(parts, GeminiChatPart{Text: &content})

	// text-only model: never attach image parts
	if model == globals.GeminiPro {
		return parts
	}

	images := utils.ExtractImageUrls(content)
	if len(images) > geminiMaxImages {
		// cap the number of inline images per request
		images = images[:geminiMaxImages]
	}

	for _, image := range images {
		encoded, err := utils.ConvertToBase64(image)
		if err != nil {
			// best-effort: skip images that cannot be fetched, keep the text
			continue
		}
		parts = append(parts, GeminiChatPart{
			InlineData: &GeminiInlineData{
				MimeType: getMimeType(image),
				Data:     encoded,
			},
		})
	}
	return parts
}
// GetGeminiContents converts the generic message history into gemini's
// "contents" format. Gemini requires that every message has non-empty
// content, that the first message comes from the user, and that authors
// strictly alternate; same-role neighbors are merged into one content entry.
func (c *ChatInstance) GetGeminiContents(model string, message []globals.Message) []GeminiContent {
	result := make([]GeminiContent, 0, len(message))
	for _, item := range message {
		if len(item.Content) == 0 {
			// gemini model: message must include non empty content
			continue
		}

		// compute the role once and reuse it (original re-called getGeminiRole)
		role := getGeminiRole(item.Role)
		if len(result) == 0 && role == GeminiModelType {
			// gemini model: first message must be from the user
			continue
		}

		if len(result) > 0 && role == result[len(result)-1].Role {
			// gemini model: messages must alternate between authors,
			// so fold this message into the previous same-role entry
			result[len(result)-1].Parts = getGeminiContent(result[len(result)-1].Parts, item.Content, model)
			continue
		}

		result = append(result, GeminiContent{
			Role:  role,
			Parts: getGeminiContent(make([]GeminiChatPart, 0), item.Content, model),
		})
	}
	return result
}

View File

@ -1,20 +1,72 @@
package palm2 package palm2
const (
	// gemini only recognizes two author roles: "user" and "model"
	GeminiUserType  = "user"
	GeminiModelType = "model"
)
type PalmMessage struct { type PalmMessage struct {
Author string `json:"author"` Author string `json:"author"`
Content string `json:"content"` Content string `json:"content"`
} }
// ChatBody is the native http request body for palm2 // PalmChatBody is the native http request body for palm2
type ChatBody struct { type PalmChatBody struct {
Prompt Prompt `json:"prompt"` Prompt PalmPrompt `json:"prompt"`
} }
type Prompt struct { type PalmPrompt struct {
Messages []PalmMessage `json:"messages"` Messages []PalmMessage `json:"messages"`
} }
// ChatResponse is the native http response body for palm2 // PalmChatResponse is the native http response body for palm2
type ChatResponse struct { type PalmChatResponse struct {
Candidates []PalmMessage `json:"candidates"` Candidates []PalmMessage `json:"candidates"`
} }
// GeminiChatBody is the native http request body for gemini
type GeminiChatBody struct {
	Contents         []GeminiContent `json:"contents"`
	GenerationConfig GeminiConfig    `json:"generationConfig"`
}

// GeminiConfig holds the optional sampling parameters; fields left nil are
// omitted from the request so the api applies its own defaults.
type GeminiConfig struct {
	Temperature     *float64 `json:"temperature,omitempty"`
	MaxOutputTokens *int     `json:"maxOutputTokens,omitempty"`
	TopP            *float64 `json:"topP,omitempty"`
	TopK            *int     `json:"topK,omitempty"`
}

// GeminiContent is a single conversation turn ("user" or "model").
type GeminiContent struct {
	Role  string           `json:"role"`
	Parts []GeminiChatPart `json:"parts"`
}

// GeminiChatPart is either a text fragment or an inline media payload;
// exactly one of the two pointers is expected to be set.
type GeminiChatPart struct {
	Text       *string           `json:"text,omitempty"`
	InlineData *GeminiInlineData `json:"inline_data,omitempty"`
}

// GeminiInlineData carries base64-encoded media together with its MIME type.
type GeminiInlineData struct {
	MimeType string `json:"mime_type"`
	Data     string `json:"data"`
}

// GeminiChatResponse is the native http response body for gemini
type GeminiChatResponse struct {
	Candidates []struct {
		Content struct {
			Parts []struct {
				Text string `json:"text"`
			} `json:"parts"`
			Role string `json:"role"`
		} `json:"content"`
	} `json:"candidates"`
}

// GeminiChatErrorResponse is the structured error payload returned by gemini
type GeminiChatErrorResponse struct {
	Error struct {
		Code    int    `json:"code"`
		Message string `json:"message"`
		Status  string `json:"status"`
	} `json:"error"`
}

Binary file not shown.

After

Width:  |  Height:  |  Size: 5.4 KiB

View File

@ -33,7 +33,7 @@ export const ChannelTypes: Record<string, string> = {
baichuan: "百川 AI", baichuan: "百川 AI",
skylark: "火山方舟", skylark: "火山方舟",
bing: "New Bing", bing: "New Bing",
palm: "Google PaLM2", palm: "Google Gemini",
midjourney: "Midjourney", midjourney: "Midjourney",
oneapi: "Nio API", oneapi: "Nio API",
}; };
@ -141,7 +141,7 @@ export const ChannelInfos: Record<string, ChannelInfo> = {
id: 11, id: 11,
endpoint: "https://generativelanguage.googleapis.com", endpoint: "https://generativelanguage.googleapis.com",
format: "<api-key>", format: "<api-key>",
models: ["chat-bison-001"], models: ["chat-bison-001", "gemini-pro", "gemini-pro-vision"],
}, },
midjourney: { midjourney: {
id: 12, id: 12,

View File

@ -43,6 +43,8 @@ export const modelColorMapper: Record<string, string> = {
"spark-desk-v3": "#06b3e8", "spark-desk-v3": "#06b3e8",
"chat-bison-001": "#f82a53", "chat-bison-001": "#f82a53",
"gemini-pro": "#f82a53",
"gemini-pro-vision": "#f82a53",
"bing-creative": "#2673e7", "bing-creative": "#2673e7",
"bing-balanced": "#2673e7", "bing-balanced": "#2673e7",

View File

@ -274,6 +274,22 @@ export const supportModels: Model[] = [
tag: ["free", "english-model"], tag: ["free", "english-model"],
}, },
// gemini
{
id: "gemini-pro",
name: "Gemini Pro",
free: true,
auth: true,
tag: ["free", "official"],
},
{
id: "gemini-pro-vision",
name: "Gemini Pro Vision",
free: true,
auth: true,
tag: ["free", "official", "multi-modal"],
},
// drawing models // drawing models
{ {
id: "midjourney", id: "midjourney",
@ -346,6 +362,9 @@ export const defaultModels = [
"zhipu-chatglm-turbo", "zhipu-chatglm-turbo",
"baichuan-53b", "baichuan-53b",
"gemini-pro",
"gemini-pro-vision",
"dall-e-2", "dall-e-2",
"midjourney-fast", "midjourney-fast",
"stable-diffusion", "stable-diffusion",
@ -412,6 +431,8 @@ export const modelAvatars: Record<string, string> = {
"midjourney-turbo": "midjourney.jpg", "midjourney-turbo": "midjourney.jpg",
"bing-creative": "newbing.jpg", "bing-creative": "newbing.jpg",
"chat-bison-001": "palm2.webp", "chat-bison-001": "palm2.webp",
"gemini-pro": "gemini.jpeg",
"gemini-pro-vision": "gemini.jpeg",
"zhipu-chatglm-turbo": "chatglm.png", "zhipu-chatglm-turbo": "chatglm.png",
"qwen-plus-net": "tongyi.png", "qwen-plus-net": "tongyi.png",
"qwen-plus": "tongyi.png", "qwen-plus": "tongyi.png",

View File

@ -168,8 +168,8 @@ function Mail({ data, dispatch, onChange }: CompProps<MailState>) {
<ParagraphFooter> <ParagraphFooter>
<div className={`grow`} /> <div className={`grow`} />
<Dialog open={mailDialog} onOpenChange={setMailDialog}> <Dialog open={mailDialog} onOpenChange={setMailDialog}>
<DialogTrigger> <DialogTrigger asChild>
<Button variant={`outline`} size={`sm`} loading={true}> <Button variant={`outline`} size={`sm`}>
{t("admin.system.test")} {t("admin.system.test")}
</Button> </Button>
</DialogTrigger> </DialogTrigger>

View File

@ -69,6 +69,8 @@ const (
SparkDeskV2 = "spark-desk-v2" SparkDeskV2 = "spark-desk-v2"
SparkDeskV3 = "spark-desk-v3" SparkDeskV3 = "spark-desk-v3"
ChatBison001 = "chat-bison-001" ChatBison001 = "chat-bison-001"
GeminiPro = "gemini-pro"
GeminiProVision = "gemini-pro-vision"
BingCreative = "bing-creative" BingCreative = "bing-creative"
BingBalanced = "bing-balanced" BingBalanced = "bing-balanced"
BingPrecise = "bing-precise" BingPrecise = "bing-precise"

View File

@ -167,8 +167,8 @@ func ExtractUrls(data string) []string {
func ExtractImageUrls(data string) []string { func ExtractImageUrls(data string) []string {
// https://platform.openai.com/docs/guides/vision/what-type-of-files-can-i-upload // https://platform.openai.com/docs/guides/vision/what-type-of-files-can-i-upload
re := regexp.MustCompile(`(https?://\S+\.(?:png|jpg|jpeg|gif|webp))`) re := regexp.MustCompile(`(https?://\S+\.(?:png|jpg|jpeg|gif|webp|heif|heic))`)
return re.FindAllString(data, -1) return re.FindAllString(strings.ToLower(data), -1)
} }
func ContainUnicode(data string) bool { func ContainUnicode(data string) bool {

View File

@ -6,6 +6,7 @@ import (
"crypto/md5" "crypto/md5"
crand "crypto/rand" crand "crypto/rand"
"crypto/sha256" "crypto/sha256"
"encoding/base64"
"encoding/hex" "encoding/hex"
"io" "io"
) )
@ -22,6 +23,30 @@ func Sha2EncryptForm(form interface{}) string {
return hex.EncodeToString(hash[:]) return hex.EncodeToString(hash[:])
} }
// Base64Encode returns the standard base64 encoding of the given string.
func Base64Encode(raw string) string {
	payload := []byte(raw)
	return base64.StdEncoding.EncodeToString(payload)
}
// Base64EncodeBytes returns the standard base64 encoding of the given bytes.
func Base64EncodeBytes(raw []byte) string {
	encoder := base64.StdEncoding
	return encoder.EncodeToString(raw)
}
// Base64Decode decodes a standard base64 string; malformed input yields the
// empty string rather than an error (callers treat it as best-effort).
func Base64Decode(raw string) string {
	// early return instead of else-after-return (idiomatic Go)
	data, err := base64.StdEncoding.DecodeString(raw)
	if err != nil {
		return ""
	}
	return string(data)
}
// Base64DecodeBytes decodes a standard base64 string; malformed input yields
// an empty (non-nil) slice rather than an error (callers treat it as
// best-effort).
func Base64DecodeBytes(raw string) []byte {
	// early return instead of else-after-return (idiomatic Go); keep the
	// non-nil empty slice to preserve the original return value exactly
	data, err := base64.StdEncoding.DecodeString(raw)
	if err != nil {
		return []byte{}
	}
	return data
}
func Md5Encrypt(raw string) string { func Md5Encrypt(raw string) string {
// return 32-bit hash // return 32-bit hash
hash := md5.Sum([]byte(raw)) hash := md5.Sum([]byte(raw))

View File

@ -6,6 +6,7 @@ import (
"image" "image"
"image/gif" "image/gif"
"image/jpeg" "image/jpeg"
"io"
"math" "math"
"net/http" "net/http"
"path" "path"
@ -51,6 +52,22 @@ func NewImage(url string) (*Image, error) {
return &Image{Object: img}, nil return &Image{Object: img}, nil
} }
// ConvertToBase64 downloads the resource at the given url and returns its
// body encoded as standard base64.
// NOTE(review): the response status code is not checked, so a 404/500 error
// page would be encoded as if it were the resource — confirm whether callers
// rely on that before tightening.
func ConvertToBase64(url string) (string, error) {
	response, err := http.Get(url)
	if err != nil {
		return "", err
	}
	defer response.Body.Close()

	body, err := io.ReadAll(response.Body)
	if err != nil {
		return "", err
	}
	return Base64EncodeBytes(body), nil
}
func (i *Image) GetWidth() int { func (i *Image) GetWidth() int {
return i.Object.Bounds().Max.X return i.Object.Bounds().Max.X
} }