Mirror of https://github.com/coaidev/coai.git (synced 2025-05-19 13:00:14 +09:00)
feat: add gemini pro, gemini pro vision models
This commit is contained in: parent 3d1e8c8f87, commit 261e500840
@@ -6,13 +6,23 @@ import (
 	"fmt"
 )
 
+var geminiMaxImages = 16
+
 type ChatProps struct {
-	Model   string
-	Message []globals.Message
+	Model           string
+	Message         []globals.Message
+	Temperature     *float64
+	TopP            *float64
+	TopK            *int
+	MaxOutputTokens *int
 }
 
 func (c *ChatInstance) GetChatEndpoint(model string) string {
-	return fmt.Sprintf("%s/v1beta2/models/%s:generateMessage?key=%s", c.Endpoint, model, c.ApiKey)
+	if model == globals.ChatBison001 {
+		return fmt.Sprintf("%s/v1beta2/models/%s:generateMessage?key=%s", c.Endpoint, model, c.ApiKey)
+	}
+
+	return fmt.Sprintf("%s/v1beta/models/%s:generateContent?key=%s", c.Endpoint, model, c.ApiKey)
 }
 
 func (c *ChatInstance) ConvertMessage(message []globals.Message) []PalmMessage {
@@ -41,33 +51,75 @@ func (c *ChatInstance) ConvertMessage(message []globals.Message) []PalmMessage {
 	return result
 }
 
-func (c *ChatInstance) GetChatBody(props *ChatProps) *ChatBody {
-	return &ChatBody{
-		Prompt: Prompt{
+func (c *ChatInstance) GetPalm2ChatBody(props *ChatProps) *PalmChatBody {
+	return &PalmChatBody{
+		Prompt: PalmPrompt{
 			Messages: c.ConvertMessage(props.Message),
 		},
 	}
 }
 
-func (c *ChatInstance) CreateChatRequest(props *ChatProps) (string, error) {
-	uri := c.GetChatEndpoint(props.Model)
-	data, err := utils.Post(uri, map[string]string{
-		"Content-Type": "application/json",
-	}, c.GetChatBody(props))
-
-	if err != nil {
-		return "", fmt.Errorf("palm2 error: %s", err.Error())
+func (c *ChatInstance) GetGeminiChatBody(props *ChatProps) *GeminiChatBody {
+	return &GeminiChatBody{
+		Contents: c.GetGeminiContents(props.Model, props.Message),
+		GenerationConfig: GeminiConfig{
+			Temperature:     props.Temperature,
+			MaxOutputTokens: props.MaxOutputTokens,
+			TopP:            props.TopP,
+			TopK:            props.TopK,
+		},
 	}
+}
 
-	if form := utils.MapToStruct[ChatResponse](data); form != nil {
+func (c *ChatInstance) GetPalm2ChatResponse(data interface{}) (string, error) {
+	if form := utils.MapToStruct[PalmChatResponse](data); form != nil {
 		if len(form.Candidates) == 0 {
-			return "I don't know how to respond to that. Please try another question.", nil
+			return "", fmt.Errorf("palm2 error: the content violates content policy")
 		}
 		return form.Candidates[0].Content, nil
 	}
 	return "", fmt.Errorf("palm2 error: cannot parse response")
 }
 
+func (c *ChatInstance) GetGeminiChatResponse(data interface{}) (string, error) {
+	if form := utils.MapToStruct[GeminiChatResponse](data); form != nil {
+		if len(form.Candidates) != 0 && len(form.Candidates[0].Content.Parts) != 0 {
+			return form.Candidates[0].Content.Parts[0].Text, nil
+		}
+	}
+
+	if form := utils.MapToStruct[GeminiChatErrorResponse](data); form != nil {
+		return "", fmt.Errorf("gemini error: %s (code: %d, status: %s)", form.Error.Message, form.Error.Code, form.Error.Status)
+	}
+
+	return "", fmt.Errorf("gemini: cannot parse response")
+}
+
+func (c *ChatInstance) CreateChatRequest(props *ChatProps) (string, error) {
+	uri := c.GetChatEndpoint(props.Model)
+
+	if props.Model == globals.ChatBison001 {
+		data, err := utils.Post(uri, map[string]string{
+			"Content-Type": "application/json",
+		}, c.GetPalm2ChatBody(props))
+
+		if err != nil {
+			return "", fmt.Errorf("palm2 error: %s", err.Error())
+		}
+		return c.GetPalm2ChatResponse(data)
+	}
+
+	data, err := utils.Post(uri, map[string]string{
+		"Content-Type": "application/json",
+	}, c.GetGeminiChatBody(props))
+
+	if err != nil {
+		return "", fmt.Errorf("gemini error: %s", err.Error())
+	}
+
+	return c.GetGeminiChatResponse(data)
+}
+
 // CreateStreamChatRequest is the mock stream request for palm2
 // tips: palm2 does not support stream request
 func (c *ChatInstance) CreateStreamChatRequest(props *ChatProps, callback globals.Hook) error {
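Not part of the commit: a minimal usage sketch of the new request path, assuming an already-constructed *ChatInstance. The helper name and all field values below are illustrative only.

package palm2

import "chat/globals"

// exampleGeminiCall is a hypothetical caller of the code above.
func exampleGeminiCall(c *ChatInstance) (string, error) {
	temperature := 0.7
	maxTokens := 2048

	props := &ChatProps{
		Model: globals.GeminiProVision,
		Message: []globals.Message{
			{Role: globals.User, Content: "Describe https://example.com/cat.png"},
		},
		Temperature:     &temperature,
		MaxOutputTokens: &maxTokens,
	}

	// chat-bison-001 is routed to GetPalm2ChatBody and the v1beta2
	// generateMessage endpoint; any other model (gemini-pro,
	// gemini-pro-vision) goes through GetGeminiChatBody and the
	// v1beta generateContent endpoint.
	return c.CreateChatRequest(props)
}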
adapter/palm2/formatter.go (new file, 106 lines)
@@ -0,0 +1,106 @@
package palm2

import (
	"chat/globals"
	"chat/utils"
	"strings"
)

func getGeminiRole(role string) string {
	switch role {
	case globals.User:
		return GeminiUserType
	case globals.Assistant, globals.Tool, globals.System:
		return GeminiModelType
	default:
		return GeminiUserType
	}
}

func getMimeType(content string) string {
	segment := strings.Split(content, ".")
	if len(segment) == 0 || len(segment) == 1 {
		return "image/png"
	}

	suffix := strings.TrimSpace(strings.ToLower(segment[len(segment)-1]))

	switch suffix {
	case "png":
		return "image/png"
	case "jpg", "jpeg":
		return "image/jpeg"
	case "gif":
		return "image/gif"
	case "webp":
		return "image/webp"
	case "heif":
		return "image/heif"
	case "heic":
		return "image/heic"
	default:
		return "image/png"
	}
}

func getGeminiContent(parts []GeminiChatPart, content string, model string) []GeminiChatPart {
	parts = append(parts, GeminiChatPart{
		Text: &content,
	})

	if model == globals.GeminiPro {
		return parts
	}

	urls := utils.ExtractImageUrls(content)
	if len(urls) > geminiMaxImages {
		urls = urls[:geminiMaxImages]
	}

	for _, url := range urls {
		data, err := utils.ConvertToBase64(url)
		if err != nil {
			continue
		}

		parts = append(parts, GeminiChatPart{
			InlineData: &GeminiInlineData{
				MimeType: getMimeType(url),
				Data:     data,
			},
		})
	}

	return parts
}

func (c *ChatInstance) GetGeminiContents(model string, message []globals.Message) []GeminiContent {
	// gemini role should be user-model

	result := make([]GeminiContent, 0)
	for _, item := range message {
		role := getGeminiRole(item.Role)
		if len(item.Content) == 0 {
			// gemini model: message must include non empty content
			continue
		}

		if len(result) == 0 && getGeminiRole(item.Role) == GeminiModelType {
			// gemini model: first message must be user
			continue
		}

		if len(result) > 0 && role == result[len(result)-1].Role {
			// gemini model: messages must alternate between authors
			result[len(result)-1].Parts = getGeminiContent(result[len(result)-1].Parts, item.Content, model)
			continue
		}

		result = append(result, GeminiContent{
			Role:  getGeminiRole(item.Role),
			Parts: getGeminiContent(make([]GeminiChatPart, 0), item.Content, model),
		})
	}

	return result
}
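Not part of the commit: a small walkthrough of GetGeminiContents under assumed inputs, showing how roles are normalized to the user/model pair and how consecutive same-role messages are merged into one content's parts.

package palm2

import (
	"fmt"

	"chat/globals"
)

// exampleContents is a hypothetical illustration of the merge rules above.
func exampleContents(c *ChatInstance) {
	messages := []globals.Message{
		// mapped to "model" and dropped: the first kept message must be "user"
		{Role: globals.System, Content: "You are a helpful assistant."},
		{Role: globals.User, Content: "Hi"},
		// same normalized role as the previous message, so it is merged into the
		// previous content's parts (plus inline_data if the image can be fetched)
		{Role: globals.User, Content: "What is this? https://example.com/cat.png"},
		{Role: globals.Assistant, Content: "It looks like a cat."},
	}

	contents := c.GetGeminiContents(globals.GeminiProVision, messages)
	fmt.Println(len(contents)) // 2: one "user" content with several parts, one "model" content
}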
@@ -1,20 +1,72 @@
 package palm2
 
+const (
+	GeminiUserType  = "user"
+	GeminiModelType = "model"
+)
+
 type PalmMessage struct {
 	Author  string `json:"author"`
 	Content string `json:"content"`
 }
 
-// ChatBody is the native http request body for palm2
-type ChatBody struct {
-	Prompt Prompt `json:"prompt"`
+// PalmChatBody is the native http request body for palm2
+type PalmChatBody struct {
+	Prompt PalmPrompt `json:"prompt"`
 }
 
-type Prompt struct {
+type PalmPrompt struct {
 	Messages []PalmMessage `json:"messages"`
 }
 
-// ChatResponse is the native http response body for palm2
-type ChatResponse struct {
+// PalmChatResponse is the native http response body for palm2
+type PalmChatResponse struct {
 	Candidates []PalmMessage `json:"candidates"`
 }
+
+// GeminiChatBody is the native http request body for gemini
+type GeminiChatBody struct {
+	Contents         []GeminiContent `json:"contents"`
+	GenerationConfig GeminiConfig    `json:"generationConfig"`
+}
+
+type GeminiConfig struct {
+	Temperature     *float64 `json:"temperature,omitempty"`
+	MaxOutputTokens *int     `json:"maxOutputTokens,omitempty"`
+	TopP            *float64 `json:"topP,omitempty"`
+	TopK            *int     `json:"topK,omitempty"`
+}
+
+type GeminiContent struct {
+	Role  string           `json:"role"`
+	Parts []GeminiChatPart `json:"parts"`
+}
+
+type GeminiChatPart struct {
+	Text       *string           `json:"text,omitempty"`
+	InlineData *GeminiInlineData `json:"inline_data,omitempty"`
+}
+
+type GeminiInlineData struct {
+	MimeType string `json:"mime_type"`
+	Data     string `json:"data"`
+}
+
+type GeminiChatResponse struct {
+	Candidates []struct {
+		Content struct {
+			Parts []struct {
+				Text string `json:"text"`
+			} `json:"parts"`
+			Role string `json:"role"`
+		} `json:"content"`
+	} `json:"candidates"`
+}
+
+type GeminiChatErrorResponse struct {
+	Error struct {
+		Code    int    `json:"code"`
+		Message string `json:"message"`
+		Status  string `json:"status"`
+	} `json:"error"`
+}
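Not part of the commit: a quick check of the wire format implied by the struct tags above. Marshaling a minimal GeminiChatBody (values are illustrative) yields the shape shown in the trailing comment.

package palm2

import (
	"encoding/json"
	"fmt"
)

// exampleBody is a hypothetical illustration of the request shape.
func exampleBody() {
	text := "Hello"
	temperature := 0.9

	body := GeminiChatBody{
		Contents: []GeminiContent{
			{Role: GeminiUserType, Parts: []GeminiChatPart{{Text: &text}}},
		},
		GenerationConfig: GeminiConfig{Temperature: &temperature},
	}

	raw, _ := json.Marshal(body)
	fmt.Println(string(raw))
	// {"contents":[{"role":"user","parts":[{"text":"Hello"}]}],"generationConfig":{"temperature":0.9}}
}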
app/public/icons/gemini.jpeg (new binary file, 5.4 KiB; not shown)
@@ -33,7 +33,7 @@ export const ChannelTypes: Record<string, string> = {
   baichuan: "百川 AI",
   skylark: "火山方舟",
   bing: "New Bing",
-  palm: "Google PaLM2",
+  palm: "Google Gemini",
   midjourney: "Midjourney",
   oneapi: "Nio API",
 };
@@ -141,7 +141,7 @@ export const ChannelInfos: Record<string, ChannelInfo> = {
     id: 11,
     endpoint: "https://generativelanguage.googleapis.com",
     format: "<api-key>",
-    models: ["chat-bison-001"],
+    models: ["chat-bison-001", "gemini-pro", "gemini-pro-vision"],
   },
   midjourney: {
     id: 12,
@@ -43,6 +43,8 @@ export const modelColorMapper: Record<string, string> = {
   "spark-desk-v3": "#06b3e8",
 
   "chat-bison-001": "#f82a53",
+  "gemini-pro": "#f82a53",
+  "gemini-pro-vision": "#f82a53",
 
   "bing-creative": "#2673e7",
   "bing-balanced": "#2673e7",
@@ -274,6 +274,22 @@ export const supportModels: Model[] = [
     tag: ["free", "english-model"],
   },
 
+  // gemini
+  {
+    id: "gemini-pro",
+    name: "Gemini Pro",
+    free: true,
+    auth: true,
+    tag: ["free", "official"],
+  },
+  {
+    id: "gemini-pro-vision",
+    name: "Gemini Pro Vision",
+    free: true,
+    auth: true,
+    tag: ["free", "official", "multi-modal"],
+  },
+
   // drawing models
   {
     id: "midjourney",
@@ -346,6 +362,9 @@ export const defaultModels = [
   "zhipu-chatglm-turbo",
   "baichuan-53b",
 
+  "gemini-pro",
+  "gemini-pro-vision",
+
   "dall-e-2",
   "midjourney-fast",
   "stable-diffusion",
@@ -412,6 +431,8 @@ export const modelAvatars: Record<string, string> = {
   "midjourney-turbo": "midjourney.jpg",
   "bing-creative": "newbing.jpg",
   "chat-bison-001": "palm2.webp",
+  "gemini-pro": "gemini.jpeg",
+  "gemini-pro-vision": "gemini.jpeg",
   "zhipu-chatglm-turbo": "chatglm.png",
   "qwen-plus-net": "tongyi.png",
   "qwen-plus": "tongyi.png",
@@ -168,8 +168,8 @@ function Mail({ data, dispatch, onChange }: CompProps<MailState>) {
       <ParagraphFooter>
         <div className={`grow`} />
         <Dialog open={mailDialog} onOpenChange={setMailDialog}>
-          <DialogTrigger>
-            <Button variant={`outline`} size={`sm`} loading={true}>
+          <DialogTrigger asChild>
+            <Button variant={`outline`} size={`sm`}>
               {t("admin.system.test")}
             </Button>
           </DialogTrigger>
@@ -69,6 +69,8 @@ const (
 	SparkDeskV2     = "spark-desk-v2"
 	SparkDeskV3     = "spark-desk-v3"
 	ChatBison001    = "chat-bison-001"
+	GeminiPro       = "gemini-pro"
+	GeminiProVision = "gemini-pro-vision"
 	BingCreative    = "bing-creative"
 	BingBalanced    = "bing-balanced"
 	BingPrecise     = "bing-precise"
@@ -167,8 +167,8 @@ func ExtractUrls(data string) []string {
 func ExtractImageUrls(data string) []string {
 	// https://platform.openai.com/docs/guides/vision/what-type-of-files-can-i-upload
 
-	re := regexp.MustCompile(`(https?://\S+\.(?:png|jpg|jpeg|gif|webp))`)
-	return re.FindAllString(data, -1)
+	re := regexp.MustCompile(`(https?://\S+\.(?:png|jpg|jpeg|gif|webp|heif|heic))`)
+	return re.FindAllString(strings.ToLower(data), -1)
 }
 
 func ContainUnicode(data string) bool {
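Not part of the commit: an illustration of the updated matching behavior with an assumed input. The data is lower-cased before matching, so mixed-case extensions such as .HEIC are caught, and the returned URLs come back lower-cased as well.

package utils

import "fmt"

// exampleExtract is a hypothetical caller of ExtractImageUrls.
func exampleExtract() {
	input := "see https://example.com/Photo.HEIC and https://example.com/doc.pdf"
	fmt.Println(ExtractImageUrls(input))
	// [https://example.com/photo.heic]
}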
@@ -6,6 +6,7 @@ import (
 	"crypto/md5"
 	crand "crypto/rand"
 	"crypto/sha256"
+	"encoding/base64"
 	"encoding/hex"
 	"io"
 )
@@ -22,6 +23,30 @@ func Sha2EncryptForm(form interface{}) string {
 	return hex.EncodeToString(hash[:])
 }
 
+func Base64Encode(raw string) string {
+	return base64.StdEncoding.EncodeToString([]byte(raw))
+}
+
+func Base64EncodeBytes(raw []byte) string {
+	return base64.StdEncoding.EncodeToString(raw)
+}
+
+func Base64Decode(raw string) string {
+	if data, err := base64.StdEncoding.DecodeString(raw); err == nil {
+		return string(data)
+	} else {
+		return ""
+	}
+}
+
+func Base64DecodeBytes(raw string) []byte {
+	if data, err := base64.StdEncoding.DecodeString(raw); err == nil {
+		return data
+	} else {
+		return []byte{}
+	}
+}
+
 func Md5Encrypt(raw string) string {
 	// return 32-bit hash
 	hash := md5.Sum([]byte(raw))
@@ -6,6 +6,7 @@ import (
 	"image"
 	"image/gif"
 	"image/jpeg"
+	"io"
 	"math"
 	"net/http"
 	"path"
@@ -51,6 +52,22 @@ func NewImage(url string) (*Image, error) {
 	return &Image{Object: img}, nil
 }
 
+func ConvertToBase64(url string) (string, error) {
+	res, err := http.Get(url)
+	if err != nil {
+		return "", err
+	}
+
+	defer res.Body.Close()
+
+	data, err := io.ReadAll(res.Body)
+	if err != nil {
+		return "", err
+	}
+
+	return Base64EncodeBytes(data), nil
+}
+
 func (i *Image) GetWidth() int {
 	return i.Object.Bounds().Max.X
 }
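Not part of the commit: a hypothetical use of the new ConvertToBase64 helper (URL illustrative), which getGeminiContent in adapter/palm2/formatter.go relies on to build the inline_data parts.

package utils

import "fmt"

// exampleInline sketches how fetched bytes become inline image data.
func exampleInline() {
	data, err := ConvertToBase64("https://example.com/cat.png")
	if err != nil {
		fmt.Println("fetch failed:", err)
		return
	}

	// data is the base64-encoded image body, ready to be placed into a
	// palm2.GeminiInlineData{MimeType: "image/png", Data: data} part.
	fmt.Println(len(data))
}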