add gpt-4-vision-preview support

This commit is contained in:
Zhang Minghan 2023-11-21 23:03:13 +08:00
parent bb783f04b2
commit 4ee2051ca9
7 changed files with 128 additions and 16 deletions

View File

@ -24,16 +24,59 @@ func processFormat(data string) string {
return item return item
} }
func formatMessages(props *ChatProps) []globals.Message { func formatMessages(props *ChatProps) interface{} {
if props.Model == globals.GPT4Vision { if props.Model == globals.GPT4Vision {
base := props.Message[len(props.Message)-1].Content base := props.Message[len(props.Message)-1].Content
urls := utils.ExtractUrls(base) urls := utils.ExtractImageUrls(base)
if len(urls) > 0 { if len(urls) > 0 {
base = fmt.Sprintf("%s %s", strings.Join(urls, " "), base) base = fmt.Sprintf("%s %s", strings.Join(urls, " "), base)
} }
props.Message[len(props.Message)-1].Content = base props.Message[len(props.Message)-1].Content = base
return props.Message return props.Message
} else if props.Model == globals.GPT41106VisionPreview {
return utils.Each[globals.Message, Message](props.Message, func(message globals.Message) Message {
if message.Role == globals.User {
urls := utils.ExtractImageUrls(message.Content)
images := utils.EachNotNil[string, MessageContent](urls, func(url string) *MessageContent {
obj, err := utils.NewImage(url)
if err != nil {
return nil
}
props.Buffer.AddImage(obj)
return &MessageContent{
Type: "image_url",
ImageUrl: &ImageUrl{
Url: url,
},
}
})
return Message{
Role: message.Role,
Content: utils.Prepend(images, MessageContent{
Type: "text",
Text: &message.Content,
}),
ToolCalls: message.ToolCalls,
ToolCallId: message.ToolCallId,
}
}
return Message{
Role: message.Role,
Content: MessageContents{
MessageContent{
Type: "text",
Text: &message.Content,
},
},
ToolCalls: message.ToolCalls,
ToolCallId: message.ToolCallId,
}
})
} }
return props.Message return props.Message

View File

@ -2,10 +2,30 @@ package chatgpt
import "chat/globals" import "chat/globals"
// ImageUrl is the payload of an "image_url" content part.
// Url is always serialized under the "url" key. Detail is optional and is
// dropped from the JSON when nil (omitempty).
// NOTE(review): Detail is presumably the vision detail level accepted by the
// upstream API; nothing visible here sets it — confirm allowed values.
type ImageUrl struct {
	Url string `json:"url"`
	Detail *string `json:"detail,omitempty"`
}
// MessageContent is a single typed part of a multi-part message body.
// Type is a discriminator string; Text and ImageUrl are both optional
// pointers and are omitted from the JSON when nil, so a part normally
// carries only the field that corresponds to its Type.
type MessageContent struct {
	Type string `json:"type"`
	Text *string `json:"text,omitempty"`
	ImageUrl *ImageUrl `json:"image_url,omitempty"`
}
// MessageContents is the ordered list of content parts that make up the
// body of one Message.
type MessageContents []MessageContent
// Message is a chat message whose content is a list of typed parts
// (MessageContents) instead of a single string.
// ToolCallId and ToolCalls are optional and omitted from the JSON when nil.
type Message struct {
	Role string `json:"role"`
	Content MessageContents `json:"content"`
	ToolCallId *string `json:"tool_call_id,omitempty"` // only `tool` role
	ToolCalls *globals.ToolCalls `json:"tool_calls,omitempty"` // only `assistant` role
}
// ChatRequest is the request body for chatgpt // ChatRequest is the request body for chatgpt
type ChatRequest struct { type ChatRequest struct {
Model string `json:"model"` Model string `json:"model"`
Messages []globals.Message `json:"messages"` Messages interface{} `json:"messages"`
MaxToken *int `json:"max_tokens,omitempty"` MaxToken *int `json:"max_tokens,omitempty"`
Stream bool `json:"stream"` Stream bool `json:"stream"`
PresencePenalty *float32 `json:"presence_penalty,omitempty"` PresencePenalty *float32 `json:"presence_penalty,omitempty"`
@ -16,7 +36,7 @@ type ChatRequest struct {
ToolChoice *interface{} `json:"tool_choice,omitempty"` // string or object ToolChoice *interface{} `json:"tool_choice,omitempty"` // string or object
} }
// CompletionRequest ChatRequest is the request body for chatgpt completion // CompletionRequest is the request body for chatgpt completion
type CompletionRequest struct { type CompletionRequest struct {
Model string `json:"model"` Model string `json:"model"`
Prompt string `json:"prompt"` Prompt string `json:"prompt"`

View File

@ -3,6 +3,7 @@ package skylark
import ( import (
"chat/globals" "chat/globals"
"chat/utils" "chat/utils"
"fmt"
"github.com/volcengine/volc-sdk-golang/service/maas" "github.com/volcengine/volc-sdk-golang/service/maas"
"github.com/volcengine/volc-sdk-golang/service/maas/models/api" "github.com/volcengine/volc-sdk-golang/service/maas/models/api"
) )
@ -19,6 +20,7 @@ type ChatProps struct {
TopP *float32 TopP *float32
TopK *int TopK *int
Tools *globals.FunctionTools Tools *globals.FunctionTools
Buffer utils.Buffer
} }
func getMessages(messages []globals.Message) []*api.Message { func getMessages(messages []globals.Message) []*api.Message {
@ -54,6 +56,27 @@ func (c *ChatInstance) CreateRequest(props *ChatProps) *api.ChatReq {
} }
} }
// getChoice returns the text content of a chat response and, when the
// response carries a function call, records it on buffer as a single
// "function" tool call whose id is "<function name>-<request id>".
//
// A nil response yields the empty string and leaves buffer untouched.
//
// NOTE(review): buffer is received by value. If SetToolCalls has a pointer
// receiver, the tool calls are stored on this local copy — confirm that the
// caller actually observes them.
func getChoice(choice *api.ChatResp, buffer utils.Buffer) string {
	if choice == nil {
		return ""
	}

	if fn := choice.Choice.Message.FunctionCall; fn != nil {
		call := globals.ToolCall{
			Type: "function",
			Id:   globals.ToolCallId(fmt.Sprintf("%s-%s", fn.Name, choice.ReqId)),
			Function: globals.ToolCallFunction{
				Name:      fn.Name,
				Arguments: fn.Arguments,
			},
		}
		buffer.SetToolCalls(&globals.ToolCalls{call})
	}

	return choice.Choice.Message.Content
}
func (c *ChatInstance) CreateStreamChatRequest(props *ChatProps, callback globals.Hook) error { func (c *ChatInstance) CreateStreamChatRequest(props *ChatProps, callback globals.Hook) error {
req := c.CreateRequest(props) req := c.CreateRequest(props)
channel, err := c.Instance.StreamChat(req) channel, err := c.Instance.StreamChat(req)
@ -66,7 +89,7 @@ func (c *ChatInstance) CreateStreamChatRequest(props *ChatProps, callback global
return partial.Error return partial.Error
} }
if err := callback(partial.Choice.Message.Content); err != nil { if err := callback(getChoice(partial, props.Buffer)); err != nil {
return err return err
} }
} }

View File

@ -72,12 +72,19 @@ func getChoice(form *ChatResponse, buffer utils.Buffer) string {
return "" return ""
} }
buffer.SetToolCalls(&globals.ToolCalls{ if resp[0].FunctionCall != nil {
globals.ToolCall{ buffer.SetToolCalls(&globals.ToolCalls{
Type: "text", globals.ToolCall{
Id: globals.ToolCallId(form.Header.Sid), Type: "function",
}, Id: globals.ToolCallId(fmt.Sprintf("%s-%s", resp[0].FunctionCall.Name, resp[0].FunctionCall.Arguments)),
}) Function: globals.ToolCallFunction{
Name: resp[0].FunctionCall.Name,
Arguments: resp[0].FunctionCall.Arguments,
},
},
})
}
return resp[0].Content return resp[0].Content
} }

View File

@ -53,6 +53,22 @@ func InsertSlice[T any](arr []T, index int, value []T) []T {
return arr return arr
} }
// Append adds value to the end of arr and returns the resulting slice.
// It has the semantics of the built-in append: the result may share its
// backing array with arr when arr has spare capacity.
func Append[T any](arr []T, value T) []T {
	arr = append(arr, value)
	return arr
}
// AppendSlice adds every element of value to the end of arr, in order, and
// returns the resulting slice. It has the semantics of the built-in append:
// the result may share its backing array with arr.
func AppendSlice[T any](arr []T, value []T) []T {
	merged := append(arr, value...)
	return merged
}
// Prepend returns a newly allocated slice consisting of value followed by
// the elements of arr. arr itself is never modified.
func Prepend[T any](arr []T, value T) []T {
	out := make([]T, 0, len(arr)+1)
	out = append(out, value)
	out = append(out, arr...)
	return out
}
// PrependSlice returns a slice holding the elements of value followed by the
// elements of arr.
//
// The result is built in freshly allocated storage. Appending arr directly
// onto value would write into value's spare capacity, so two results derived
// from the same prefix could silently overwrite each other.
//
// When both inputs are empty, value is returned unchanged so that the
// nil-ness of the result is the same as before.
func PrependSlice[T any](arr []T, value []T) []T {
	total := len(value) + len(arr)
	if total == 0 {
		return value
	}
	out := make([]T, 0, total)
	out = append(out, value...)
	return append(out, arr...)
}
func Remove[T any](arr []T, index int) []T { func Remove[T any](arr []T, index int) []T {
return append(arr[:index], arr[index+1:]...) return append(arr[:index], arr[index+1:]...)
} }

View File

@ -45,12 +45,10 @@ func (b *Buffer) GetChunk() string {
return b.Latest return b.Latest
} }
func (b *Buffer) SetImages(images Images) { func (b *Buffer) AddImage(image *Image) {
b.Images = images b.Images = append(b.Images, *image)
b.Quota += Sum(Each(images, func(image Image) float32 { b.Quota += float32(image.CountTokens(b.Model)) * 0.7
return float32(image.CountTokens(b.Model)) * 0.7
}))
} }
func (b *Buffer) GetImages() Images { func (b *Buffer) GetImages() Images {

View File

@ -137,3 +137,8 @@ func ExtractUrls(data string) []string {
re := regexp.MustCompile(`(https?://\S+)`) re := regexp.MustCompile(`(https?://\S+)`)
return re.FindAllString(data, -1) return re.FindAllString(data, -1)
} }
// imageURLPattern matches an http:// or https:// URL that ends in one of the
// listed image file extensions. It is compiled once at package
// initialisation instead of on every ExtractImageUrls call.
var imageURLPattern = regexp.MustCompile(`(https?://\S+\.(?:png|jpg|jpeg|gif|webp|svg|bmp))`)

// ExtractImageUrls returns every image URL found in data, in order of
// appearance. A match stops at the image extension, so anything following
// it (for example a query string) is not included. Extensions are matched
// case-sensitively. The result is nil when data contains no match.
func ExtractImageUrls(data string) []string {
	return imageURLPattern.FindAllString(data, -1)
}